versae committed
Commit 3a73ca8
1 parent: f624ac4

mzjvp6ho: saving weights and logs of step 5k

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.

Files changed (50)
  1. events.out.tfevents.1660117225.t1v-n-eedfb410-w-0.55204.0.v2 +3 -0
  2. events.out.tfevents.1660130897.t1v-n-eedfb410-w-0.8420.0.v2 +3 -0
  3. events.out.tfevents.1660143983.t1v-n-eedfb410-w-0.3332902.0.v2 +3 -0
  4. events.out.tfevents.1660145355.t1v-n-eedfb410-w-0.2349240.0.v2 +3 -0
  5. events.out.tfevents.1660206880.t1v-n-eedfb410-w-0.1479163.0.v2 +3 -0
  6. events.out.tfevents.1660208728.t1v-n-eedfb410-w-0.503538.0.v2 +3 -0
  7. events.out.tfevents.1660218137.t1v-n-eedfb410-w-0.2916397.0.v2 +3 -0
  8. flax_model.msgpack +1 -1
  9. run.recover.sh +1 -2
  10. run.sh +5 -7
  11. run_flax_speech_recognition_ctc.py +5 -5
  12. special_tokens_map.json +98 -0
  13. wandb/debug-internal.log +1 -1
  14. wandb/debug.log +1 -1
  15. wandb/latest-run +1 -1
  16. wandb/run-20220810_073735-23avj35z/files/code/run_flax_speech_recognition_ctc.py +1631 -0
  17. wandb/run-20220810_073735-23avj35z/files/config.yaml +33 -0
  18. wandb/run-20220810_073735-23avj35z/files/diff.patch +27 -0
  19. wandb/run-20220810_073735-23avj35z/files/output.log +3 -0
  20. wandb/run-20220810_073735-23avj35z/files/requirements.txt +158 -0
  21. wandb/run-20220810_073735-23avj35z/files/wandb-metadata.json +70 -0
  22. wandb/run-20220810_073735-23avj35z/files/wandb-summary.json +1 -0
  23. wandb/run-20220810_073735-23avj35z/logs/debug-internal.log +3 -0
  24. wandb/run-20220810_073735-23avj35z/logs/debug.log +3 -0
  25. wandb/run-20220810_073735-23avj35z/run-23avj35z.wandb +3 -0
  26. wandb/run-20220810_111559-290849gb/files/code/run_flax_speech_recognition_ctc.py +1631 -0
  27. wandb/run-20220810_111559-290849gb/files/config.yaml +33 -0
  28. wandb/run-20220810_111559-290849gb/files/diff.patch +52 -0
  29. wandb/run-20220810_111559-290849gb/files/output.log +3 -0
  30. wandb/run-20220810_111559-290849gb/files/requirements.txt +158 -0
  31. wandb/run-20220810_111559-290849gb/files/wandb-metadata.json +70 -0
  32. wandb/run-20220810_111559-290849gb/files/wandb-summary.json +1 -0
  33. wandb/run-20220810_111559-290849gb/logs/debug-internal.log +3 -0
  34. wandb/run-20220810_111559-290849gb/logs/debug.log +3 -0
  35. wandb/run-20220810_111559-290849gb/run-290849gb.wandb +3 -0
  36. wandb/run-20220810_145446-1k92sv35/files/code/run_flax_speech_recognition_ctc.py +1632 -0
  37. wandb/run-20220810_145446-1k92sv35/files/config.yaml +33 -0
  38. wandb/run-20220810_145446-1k92sv35/files/diff.patch +132 -0
  39. wandb/run-20220810_145446-1k92sv35/files/output.log +3 -0
  40. wandb/run-20220810_145446-1k92sv35/files/requirements.txt +158 -0
  41. wandb/run-20220810_145446-1k92sv35/files/wandb-metadata.json +69 -0
  42. wandb/run-20220810_145446-1k92sv35/files/wandb-summary.json +1 -0
  43. wandb/run-20220810_145446-1k92sv35/logs/debug-internal.log +3 -0
  44. wandb/run-20220810_145446-1k92sv35/logs/debug.log +3 -0
  45. wandb/run-20220810_145446-1k92sv35/run-1k92sv35.wandb +3 -0
  46. wandb/run-20220810_151736-2jo5la5b/files/code/run_flax_speech_recognition_ctc.py +1632 -0
  47. wandb/run-20220810_151736-2jo5la5b/files/config.yaml +33 -0
  48. wandb/run-20220810_151736-2jo5la5b/files/diff.patch +144 -0
  49. wandb/run-20220810_151736-2jo5la5b/files/output.log +3 -0
  50. wandb/run-20220810_151736-2jo5la5b/files/requirements.txt +158 -0
events.out.tfevents.1660117225.t1v-n-eedfb410-w-0.55204.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e2942e6466c560f7eac1111e0c5b04e5ba30bc241b0d7408d895e2e7cad769c
+ size 40
events.out.tfevents.1660130897.t1v-n-eedfb410-w-0.8420.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d3db5592c5e247b36ba97b636878b108457cd33220828a88069711c1ae23838
+ size 40
events.out.tfevents.1660143983.t1v-n-eedfb410-w-0.3332902.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f299990d7dfa6f5112ae7da90bcf6a5217143514806dc1b21b14594cfc58a389
+ size 40
events.out.tfevents.1660145355.t1v-n-eedfb410-w-0.2349240.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:646540f39538eecbf2b7e4d42ca372657d297ba5bd1f4206725e039acbea46a4
+ size 40
events.out.tfevents.1660206880.t1v-n-eedfb410-w-0.1479163.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:26e0c8ec22f6f03deb1c4c2b33e27affb7898c696044d480d3ed8861cbe6ad58
+ size 40
events.out.tfevents.1660208728.t1v-n-eedfb410-w-0.503538.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:09793df210508f083a2bc11c7343fd21bdd95cab1d58deef35d0921cd281ffa8
+ size 40
events.out.tfevents.1660218137.t1v-n-eedfb410-w-0.2916397.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:95289ecfa3686050d8cfdb3c04ca3ef3ff1d6fe0c12a59ee107d52e0ebdb0d29
+ size 40
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:d6e7fe76ddde6be27c0129735dc3bb50191fff23d8f87075f1335970abf06211
+ oid sha256:2824056b9c2f157ff862c17877b5aa4a77f0f6107345973495c02df3828b7469
  size 3850218852
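Note: the pointer above only records the object's SHA-256 and byte size; the weights themselves live in LFS storage. Below is a small, hedged sketch (not part of this commit) for checking that a locally downloaded flax_model.msgpack matches the new pointer; the local file path is an assumption.

import hashlib
import os

# Values copied from the updated LFS pointer in this commit.
EXPECTED_OID = "2824056b9c2f157ff862c17877b5aa4a77f0f6107345973495c02df3828b7469"
EXPECTED_SIZE = 3850218852

def verify_lfs_object(path: str) -> bool:
    """Return True if the file's size and SHA-256 digest match the LFS pointer."""
    if os.path.getsize(path) != EXPECTED_SIZE:
        return False
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
    return sha.hexdigest() == EXPECTED_OID

# Hypothetical usage, assuming the file was downloaded next to this script:
# verify_lfs_object("flax_model.msgpack")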
run.recover.sh CHANGED
@@ -11,9 +11,8 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
  --per_device_train_batch_size="2" \
  --per_device_eval_batch_size="2" \
  --gradient_accumulation_steps="1" \
- --precision="full_mixed" \
+ --precision="half_mixed" \
  --matmul_precision="bfloat16" \
- --multisteps \
  --learning_rate="6.394633237505332e-05" \
  --skip_steps="275000" \
  --warmup_steps="2000" \
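Note: run.recover.sh now resumes with --precision="half_mixed" rather than "full_mixed" and no longer passes --multisteps. The precision flag decides which pieces of the training state are cast to bfloat16; the casting helpers below are the ones defined in run_flax_speech_recognition_ctc.py (see the file dump further down), with a tiny illustrative usage whose toy params dict is not from the script.

import jax
import jax.numpy as jnp

def to_bf16(t):
    # downcast every float32 leaf of a pytree to bfloat16
    return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)

def to_fp32(t):
    # upcast every bfloat16 leaf back to float32 (used before the optimizer update)
    return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)

params = {"w": jnp.ones((2, 2), dtype=jnp.float32)}
print(jax.tree_util.tree_map(lambda x: x.dtype, to_bf16(params)))  # {'w': bfloat16}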
run.sh CHANGED
@@ -1,3 +1,6 @@
+ # See https://github.com/sanchit-gandhi/seq2seq-speech/issues/23#issuecomment-1122183173: do_lower_case should only be set to True for the tokenizer if the tokenizer has upper case letters in the vocab
+ # Let's also not add extra remove_punctuation
+ # And limit max duration to 25 seconds
  WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_ctc.py \
  --model_name_or_path="facebook/wav2vec2-xls-r-1b" \
  --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst-tpu" \
@@ -11,7 +14,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
  --precision="full_mixed" \
  --matmul_precision="bfloat16" \
  --multisteps \
- --learning_rate="1e-4" \
+ --learning_rate="2e-5" \
  --warmup_steps="2000" \
  --length_column_name="input_length" \
  --evaluation_strategy="steps" \
@@ -32,7 +35,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
  --mask_feature_length="64" \
  --gradient_checkpointing \
  --min_duration_in_seconds="0.5" \
- --max_duration_in_seconds="30.0" \
+ --max_duration_in_seconds="25.0" \
  --use_auth_token \
  --seed="42" \
  --group_by_length \
@@ -40,10 +43,5 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
  --push_to_hub \
  --preprocessing_num_workers="32" \
  --ctc_zero_infinity \
- --do_lower_case \
  --wandb_project="wav2vec2" \
  --wandb_name="wav2vec2-1b-npsc-nst-tpu" \
- --remove_punctuation
-
-
- # --fp16
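Note: the new comment at the top of run.sh argues that --do_lower_case only makes sense when the tokenizer's vocabulary actually contains upper-case letters. A hedged sketch of that check follows; the model id is the one from run.sh, running it may require authentication, and none of this is part of the commit.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NbAiLab/wav2vec2-1b-npsc-nst-tpu")
# get_vocab() maps token strings to ids; look for any upper-case character in the tokens
has_upper = any(any(ch.isupper() for ch in token) for token in tokenizer.get_vocab())
print("--do_lower_case is useful" if has_upper else "vocab is already lower-case; drop --do_lower_case")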
run_flax_speech_recognition_ctc.py CHANGED
@@ -415,12 +415,12 @@ class MixedPrecisionTrainState(struct.PyTreeNode):
         )
 
     @classmethod
-    def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
+    def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs):
         """Creates a new instance with `step=0` and initialized `opt_state`."""
         # downcast optimizer state to bf16 if mixed-precision training
         opt_state = tx.init(to_dtype(params)) if tx is not None else None
         return cls(
-            step=0,
+            step=step,
             apply_fn=apply_fn,
             params=params,
             tx=tx,
@@ -1339,6 +1339,7 @@ def main():
 
     # Setup train state
     state = MixedPrecisionTrainState.create(
+        step=data_args.skip_steps,
         apply_fn=model.__call__,
         get_attention_mask_fn=model._get_feature_vector_attention_mask,
         params=model.params,
@@ -1517,14 +1518,13 @@ def main():
         if training_args.do_train:
             # ======================== Training ================================
             train_start = time.time()
+            # Create sampling rng
+            rng, input_rng = jax.random.split(rng)
 
             if epoch < skip_epochs:
                 logger.info(f"Skipping epoch {epoch + 1}")
                 continue
 
-            # Create sampling rng
-            rng, input_rng = jax.random.split(rng)
-
             # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
             train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
             train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
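Note: the change above threads --skip_steps into MixedPrecisionTrainState.create, so a recovered run starts its step counter at the checkpointed step instead of 0, keeping step-based logging, checkpoint naming, and the reported learning rate in line with the interrupted run. A self-contained, hedged sketch of the same pattern using plain Flax/Optax follows; the schedule shape matches create_learning_rate_fn in the script, while the concrete numbers are illustrative (275000 mirrors the --skip_steps value in run.recover.sh, the total step count is made up).

import jax.numpy as jnp
import optax
from flax.training import train_state

def make_lr_fn(num_train_steps, num_warmup_steps, learning_rate):
    # linear warmup followed by linear decay, as in create_learning_rate_fn
    warmup = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay = optax.linear_schedule(init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - num_warmup_steps)
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[num_warmup_steps])

lr_fn = make_lr_fn(num_train_steps=400_000, num_warmup_steps=2_000, learning_rate=6.394633237505332e-05)
state = train_state.TrainState.create(
    apply_fn=lambda params, x: x,          # toy apply_fn, never called here
    params={"w": jnp.zeros(2)},
    tx=optax.adamw(learning_rate=lr_fn),
)
state = state.replace(step=275_000)        # analogous to passing step=data_args.skip_steps
print(float(lr_fn(state.step)))            # the reported LR lands on the decay leg instead of restarting warmup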
special_tokens_map.json CHANGED
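Note: the diff below adds 98 lines that are all repeated "</s>"/"<s>" entries, which suggests the eos/bos tokens get re-appended to the additional special tokens each time the tokenizer is saved. A small hedged sketch for spotting that in a local checkout follows; it assumes these entries live under the additional_special_tokens key of the saved JSON.

import json
from collections import Counter

with open("special_tokens_map.json") as f:
    smap = json.load(f)

added = [entry["content"] for entry in smap.get("additional_special_tokens", [])]
print(len(added), Counter(added))  # e.g. many duplicate "<s>" / "</s>" entries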
@@ -399,6 +399,104 @@
  "rstrip": false,
  "single_word": false
  },
+ {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
  {
  "content": "</s>",
  "lstrip": false,
wandb/debug-internal.log CHANGED
@@ -1 +1 @@
- run-20220805_230151-2y71vcu4/logs/debug-internal.log
+ run-20220811_101752-mzjvp6ho/logs/debug-internal.log
wandb/debug.log CHANGED
@@ -1 +1 @@
- run-20220805_230151-2y71vcu4/logs/debug.log
+ run-20220811_101752-mzjvp6ho/logs/debug.log
wandb/latest-run CHANGED
@@ -1 +1 @@
- run-20220805_230151-2y71vcu4
+ run-20220811_101752-mzjvp6ho
wandb/run-20220810_073735-23avj35z/files/code/run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1631 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import unreplicate, pad_shard_unpad
44
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.utils import check_min_version
59
+ from transformers.utils.versions import require_version
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ check_min_version("4.17.0.dev0")
64
+
65
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ @flax.struct.dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ feature_extractor_name: Optional[str] = field(
86
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
91
+ )
92
+ use_fast_tokenizer: bool = field(
93
+ default=True,
94
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
95
+ )
96
+ model_revision: str = field(
97
+ default="main",
98
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
99
+ )
100
+ use_auth_token: bool = field(
101
+ default=False,
102
+ metadata={
103
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
104
+ "with private models)."
105
+ },
106
+ )
107
+ freeze_feature_encoder: bool = field(
108
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
109
+ )
110
+ attention_dropout: float = field(
111
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
112
+ )
113
+ activation_dropout: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
117
+ },
118
+ )
119
+ hidden_dropout: float = field(
120
+ default=0.1,
121
+ metadata={
122
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
123
+ },
124
+ )
125
+ feat_proj_dropout: float = field(
126
+ default=0.0,
127
+ metadata={
128
+ "help": "The feat proj dropout probability for feature encoder representations."
129
+ },
130
+ )
131
+ final_dropout: float = field(
132
+ default=0.0,
133
+ metadata={"help": "The dropout probability for the final projection layer."},
134
+ )
135
+ mask_time_prob: float = field(
136
+ default=0.1,
137
+ metadata={
138
+ "help": "The spec aug dropout probability for feature encoder representations."
139
+ },
140
+ )
141
+ mask_time_length: int = field(
142
+ default=10,
143
+ metadata={"help": "Length of vector span to mask along the time axis."},
144
+ )
145
+ mask_feature_prob: float = field(
146
+ default=0.0,
147
+ metadata={
148
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
149
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
150
+ },
151
+ )
152
+ mask_feature_length: int = field(
153
+ default=10,
154
+ metadata={"help": "Length of vector span to mask along the feature axis."},
155
+ )
156
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
157
+ ctc_loss_reduction: Optional[str] = field(
158
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
159
+ )
160
+ ctc_zero_infinity: Optional[bool] = field(
161
+ default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."}
162
+ )
163
+
164
+
165
+ @flax.struct.dataclass
166
+ class DataTrainingArguments:
167
+ """
168
+ Arguments pertaining to what data we are going to input our model for training and eval.
169
+ """
170
+
171
+ dataset_name: str = field(
172
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
173
+ )
174
+ dataset_config_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
176
+ )
177
+ text_column: Optional[str] = field(
178
+ default=None,
179
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
180
+ )
181
+ dataset_cache_dir: Optional[str] = field(
182
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
183
+ )
184
+ overwrite_cache: bool = field(
185
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
186
+ )
187
+ preprocessing_num_workers: Optional[int] = field(
188
+ default=None,
189
+ metadata={"help": "The number of processes to use for the preprocessing."},
190
+ )
191
+ max_train_samples: Optional[int] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ },
197
+ )
198
+ max_eval_samples: Optional[int] = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
202
+ "value if set."
203
+ },
204
+ )
205
+ max_test_samples: Optional[int] = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
209
+ "value if set."
210
+ },
211
+ )
212
+ audio_column_name: str = field(
213
+ default="audio",
214
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
215
+ )
216
+ text_column_name: str = field(
217
+ default="text",
218
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
219
+ )
220
+ max_duration_in_seconds: float = field(
221
+ default=20.0,
222
+ metadata={
223
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
224
+ },
225
+ )
226
+ min_duration_in_seconds: float = field(
227
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
228
+ )
229
+ max_label_length: Optional[int] = field(
230
+ default=512,
231
+ metadata={
232
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
233
+ "than this will be filtered."
234
+ },
235
+ )
236
+ min_label_length: Optional[int] = field(
237
+ default=2,
238
+ metadata={
239
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
240
+ "than this will be filtered."
241
+ },
242
+ )
243
+ pad_input_to_multiple_of: Optional[int] = field(
244
+ default=32000,
245
+ metadata={
246
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
247
+ "This is important to avoid triggering recompilations on TPU."
248
+ },
249
+ )
250
+ pad_target_to_multiple_of: Optional[int] = field(
251
+ default=None,
252
+ metadata={
253
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
254
+ "This is important to avoid triggering recompilations on TPU."
255
+ },
256
+ )
257
+ preprocessing_only: bool = field(
258
+ default=False,
259
+ metadata={
260
+ "help": "Whether to only do data preprocessing and skip training. "
261
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
262
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
263
+ "so that the cached datasets can consequently be loaded in distributed training"
264
+ },
265
+ )
266
+ train_split_name: str = field(
267
+ default="train",
268
+ metadata={
269
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
270
+ },
271
+ )
272
+ eval_split_name: str = field(
273
+ default="validation",
274
+ metadata={
275
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
276
+ },
277
+ )
278
+ do_lower_case: bool = field(
279
+ default=True,
280
+ metadata={"help": "Whether the target text should be lower cased."},
281
+ )
282
+ wandb_project: str = field(
283
+ default="flax-speech-recognition-ctc",
284
+ metadata={"help": "The name of the wandb project."},
285
+ )
286
+ wandb_name: str = field(
287
+ default=None,
288
+ metadata={"help": "The name of the wandb run."},
289
+ )
290
+ wandb_job_type: str = field(
291
+ default="CTC",
292
+ metadata={"help": "The name of the wandb job type."},
293
+ )
294
+ test_split_name: str = field(
295
+ default="test",
296
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
297
+ )
298
+ remove_punctuation: bool = field(
299
+ default=False, metadata={"help": "Whether or not to remove punctuation during training."}
300
+ )
301
+ skip_steps: Optional[int] = field(
302
+ default=0,
303
+ metadata={
304
+ "help": "Skip this number of steps. Useful to continue training"
305
+ },
306
+ )
307
+
308
+
309
+ # @flax.struct.dataclass
310
+ @dataclass
311
+ class FlaxTrainingArguments(TrainingArguments):
312
+ precision: str = field(
313
+ default="full",
314
+ metadata={
315
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
316
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
317
+ },
318
+ )
319
+ matmul_precision: str = field(
320
+ default="default",
321
+ metadata={
322
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
323
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
324
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
325
+ "it only changes the behaviors of calls with no such argument provided. "
326
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
327
+ },
328
+ )
329
+ multisteps: bool = field(
330
+ default=False,
331
+ metadata={
332
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
333
+ "a custom gradient accumulation implementation will be employed."
334
+ },
335
+ )
336
+
337
+
338
+ def to_fp32(t):
339
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
340
+
341
+
342
+ def to_bf16(t):
343
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
344
+
345
+
346
+ class MixedPrecisionTrainState(struct.PyTreeNode):
347
+ """Train state for use with a single Optax optimizer.
348
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
349
+
350
+ Synopsis::
351
+
352
+ state = TrainState.create(
353
+ apply_fn=model.apply,
354
+ params=variables['params'],
355
+ tx=tx)
356
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
357
+ for batch in data:
358
+ grads = grad_fn(state.params, batch)
359
+ state = state.apply_gradients(grads=grads)
360
+
361
+ Args:
362
+ step: Counter starts at 0 and is incremented by every call to
363
+ `.apply_gradients()`.
364
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
365
+ convenience to have a shorter params list for the `train_step()` function
366
+ in your training loop.
367
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
368
+ tx: An Optax gradient transformation.
369
+ opt_state: The state for `tx`.
370
+ dropout_rng: PRNG key for stochastic operations.
371
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
372
+ """
373
+
374
+ step: int
375
+ apply_fn: Callable = struct.field(pytree_node=False)
376
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
377
+ params: core.FrozenDict[str, Any]
378
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
379
+ opt_state: optax.OptState
380
+ dropout_rng: jnp.ndarray
381
+ max_grad_norm: Optional[float] = 1.0
382
+
383
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
384
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
385
+
386
+ Note that internally this function calls `.tx.update()` followed by a call
387
+ to `optax.apply_updates()` to update `params` and `opt_state`.
388
+
389
+ Args:
390
+ grads: Gradients that have the same pytree structure as `.params`.
391
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
392
+
393
+ Returns:
394
+ An updated instance of `self` with `step` incremented by one, `params`
395
+ and `opt_state` updated by applying `grads`, and additional attributes
396
+ replaced as specified by `kwargs`.
397
+ """
398
+
399
+ # clip gradients by global l2 norm
400
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
401
+ g_norm = linear_algebra.global_norm(grads)
402
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
403
+ grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
404
+
405
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
406
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
407
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
408
+
409
+ new_params = optax.apply_updates(self.params, updates)
410
+ return self.replace(
411
+ step=self.step + 1,
412
+ params=new_params,
413
+ opt_state=to_dtype(new_opt_state),
414
+ **kwargs,
415
+ )
416
+
417
+ @classmethod
418
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
419
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
420
+ # downcast optimizer state to bf16 if mixed-precision training
421
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
422
+ return cls(
423
+ step=0,
424
+ apply_fn=apply_fn,
425
+ params=params,
426
+ tx=tx,
427
+ opt_state=opt_state,
428
+ **kwargs,
429
+ )
430
+
431
+ def replicate(self):
432
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
433
+
434
+
435
+ @flax.struct.dataclass
436
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
437
+ """
438
+ Data collator that will dynamically pad the inputs received.
439
+ Args:
440
+ processor ([`Wav2Vec2Processor`])
441
+ The processor used for proccessing the data.
442
+ decoder_start_token_id (:obj: `int`)
443
+ The begin-of-sentence of the decoder.
444
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
445
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
446
+ among:
447
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
448
+ sequence if provided).
449
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
450
+ maximum acceptable input length for the model if that argument is not provided.
451
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
452
+ different lengths).
453
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
454
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
455
+ See above for details.
456
+ max_input_length (:obj:`float`, `optional`):
457
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
458
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
459
+ If set will pad the input sequence to a multiple of the provided value.
460
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
461
+ 7.5 (Volta).
462
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
463
+ If set will pad the target sequence to a multiple of the provided value.
464
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
465
+ 7.5 (Volta).
466
+ """
467
+
468
+ processor: Any
469
+ input_padding: Union[bool, str] = "longest"
470
+ label_padding: Union[bool, str] = "max_length"
471
+ pad_input_to_multiple_of: Optional[int] = None
472
+ pad_to_multiple_of_label: Optional[int] = None
473
+ max_input_length: Optional[float] = None
474
+ max_label_length: Optional[float] = None
475
+
476
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
477
+ # split inputs and labels since they have to be of different lengths and need
478
+ # different padding methods
479
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
480
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
481
+
482
+ # reformat list to dict and set to pytorch format
483
+ batch = self.processor.feature_extractor.pad(
484
+ input_features,
485
+ max_length=self.max_input_length,
486
+ padding=self.input_padding,
487
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
488
+ return_tensors="np",
489
+ )
490
+
491
+ labels_batch = self.processor.tokenizer.pad(
492
+ label_features,
493
+ max_length=self.max_label_length,
494
+ padding=self.label_padding,
495
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
496
+ return_tensors="np",
497
+ )
498
+
499
+ labels = labels_batch["input_ids"]
500
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
501
+ labels = labels.filled(fill_value=-100)
502
+
503
+ batch["labels"] = labels
504
+
505
+ return batch
506
+
507
+
508
+ def get_grouped_indices(
509
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
510
+ ) -> np.array:
511
+ """
512
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
513
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
514
+ lengths. To do this, the indices are:
515
+
516
+ - randomly permuted (if a JAX rng is specified)
517
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
518
+ - sorted by length in each mega-batch
519
+
520
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
521
+ maximum length placed first, so that an OOM happens sooner rather than later.
522
+ """
523
+ lengths = dataset["input_length"]
524
+
525
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
526
+ if mega_batch_mult is None:
527
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
528
+ # Just in case, for tiny datasets
529
+ if mega_batch_mult == 0:
530
+ mega_batch_mult = 1
531
+
532
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
533
+ num_samples = len(lengths)
534
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
535
+
536
+ megabatch_size = mega_batch_mult * batch_size
537
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
538
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
539
+
540
+ # The rest is to get the biggest batch first.
541
+ # Since each megabatch is sorted by descending length, the longest element is the first
542
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
543
+ max_idx = np.argmax(megabatch_maximums).item()
544
+ # Switch to put the longest batch in first position
545
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
546
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
547
+
548
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
549
+
550
+ return megabatches
551
+
552
+
553
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
554
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
555
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
556
+ num_samples = len(samples_idx)
557
+ if drop_last:
558
+ samples_to_remove = num_samples % batch_size
559
+ if samples_to_remove != 0:
560
+ samples_idx = samples_idx[:-samples_to_remove]
561
+ sections_split = num_samples // batch_size
562
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
563
+ else:
564
+ sections_split = math.ceil(num_samples / batch_size)
565
+ samples_idx = np.array_split(samples_idx, sections_split)
566
+ return samples_idx
567
+
568
+
569
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
570
+ summary_writer.scalar("train_time", train_time, step)
571
+
572
+ train_metrics = get_metrics(train_metrics)
573
+ for key, vals in train_metrics.items():
574
+ tag = f"train_{key}"
575
+ for i, val in enumerate(vals):
576
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
577
+
578
+
579
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
580
+ for metric_name, value in eval_metrics.items():
581
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
582
+
583
+ if pred_str is not None:
584
+ # write output actual predictions for debugging
585
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
586
+
587
+
588
+ def write_wandb_log(metrics, step, prefix=None):
589
+ if jax.process_index() == 0:
590
+ log_metrics = {}
591
+ for k, v in metrics.items():
592
+ if "layer" in k:
593
+ log_metrics[f"{k}/"] = v
594
+ elif prefix is not None:
595
+ log_metrics[f"{prefix}/{k}"] = v
596
+ else:
597
+ log_metrics[k] = v
598
+ wandb.log(log_metrics, step)
599
+
600
+
601
+ def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"):
602
+ if jax.process_index() == 0:
603
+ # convert str data to a wandb compatible format
604
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
605
+ # we'll log the first 50 predictions for each epoch
606
+ wandb.log(
607
+ {
608
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
609
+ columns=["label_str", "pred_str"], data=str_data[:num_log]
610
+ )
611
+ },
612
+ step,
613
+ )
614
+
615
+
616
+ def create_learning_rate_fn(
617
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
618
+ ) -> Callable[[int], jnp.array]:
619
+ """Returns a linear warmup, linear_decay learning rate function."""
620
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
621
+ decay_fn = optax.linear_schedule(
622
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
623
+ )
624
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
625
+ return schedule_fn
626
+
627
+
628
+ def ctc_loss(
629
+ logits,
630
+ logits_attention_mask,
631
+ labels,
632
+ blank_id,
633
+ loss_reduction="mean",
634
+ output_emission_dict=False,
635
+ log_epsilon=-100000.0,
636
+ ):
637
+ """Computes CTC loss.
638
+ This function performs forward computation over an FSA with `N * 2` states
639
+ where `N` is the max number of labels. The states are split into two groups:
640
+ Phi states and emission states. a phi-state accepts repetition of
641
+ phi (blank)-symbols and transits to emission state when the correct label is
642
+ observed. An emission state accepts repetition of the label and transits to
643
+ the next phi states at any time (so called epsilon-transition).
644
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
645
+ and `N` denotes the time steps in `labels`.
646
+ Args:
647
+ logits: (B, T, K)-array containing log-probabilities of each class.
648
+ logitpaddings: (B, T)-array. Padding indicators for `logits`.
649
+ labels: (B, N)-array containing reference integer labels.
650
+ labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently,
651
+ `labels` must be right-padded, i.e. each row of `labelpaddings` must be
652
+ repetition of zeroes, followed by repetition of ones.
653
+ blank_id: Id for blank token.
654
+ loss_reduction: one of "mean", "sum", "default"
655
+ - "none": no reduction is applied.
656
+ - "mean": output loss will be divided by target lengths and then the
657
+ mean over the batch is taken.
658
+ - "sum": output loss are summed over batch
659
+ output_emission_dict: whether to output additional information about the emission probs
660
+ Returns:
661
+ A pair of `(per_seq_loss, aux)`.
662
+ per_seq_loss:
663
+ (B,)-array containing loss values for each sequence in the batch.
664
+ aux: Dictionary containing interim variables used for computing losses.
665
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
666
+ phi-state corresponding to the n-th label.
667
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
668
+ emission-state corresponding to the n-th label.
669
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
670
+ corresponding to each time frame.
671
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
672
+ corresponding to each time frame.
673
+ """
674
+ # label paddings are indicated by -100
675
+ labelpaddings = labels < 0
676
+ # logit paddings are the inverse of attention_mask
677
+ logitpaddings = ~logits_attention_mask
678
+
679
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
680
+ batchsize, unused_maxinputlen, num_classes = logits.shape
681
+ batchsize_, maxlabellen = labels.shape
682
+
683
+ logprobs = jax.nn.log_softmax(logits)
684
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
685
+
686
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
687
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
688
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
689
+
690
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
691
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
692
+
693
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
694
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
695
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
696
+
697
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
698
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
699
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
700
+
701
+ def loop_body(prev, x):
702
+ prev_phi, prev_emit = prev
703
+ # emit-to-phi epsilon transition, except if the next label is repetition
704
+ prev_phi_orig = prev_phi
705
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
706
+
707
+ logprob_emit, logprob_phi, pad = x
708
+
709
+ # phi-to-emit transition
710
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
711
+ # self-loop transition
712
+ next_phi = prev_phi + logprob_phi
713
+ # emit-to-phi blank transition only when the next label is repetition
714
+ next_phi = next_phi.at[:, 1:].set(
715
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
716
+ )
717
+
718
+ pad = pad.reshape((batchsize, 1))
719
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
720
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
721
+
722
+ return (next_phi, next_emit), (next_phi, next_emit)
723
+
724
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
725
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
726
+
727
+ # last row needs to be updated with the last epsilon transition
728
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
729
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
730
+
731
+ # extract per_seq_loss
732
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
733
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
734
+
735
+ if loss_reduction == "mean":
736
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
737
+ loss = (per_seq_loss / target_lengths).mean()
738
+ elif loss_reduction == "sum":
739
+ loss = per_seq_loss.sum()
740
+ else:
741
+ loss = per_seq_loss
742
+
743
+ if not output_emission_dict:
744
+ return loss
745
+
746
+ return loss, {
747
+ "logalpha_phi": logalpha_phi,
748
+ "logalpha_emit": logalpha_emit,
749
+ "logprobs_phi": logprobs_phi,
750
+ "logprobs_emit": logprobs_emit,
751
+ }
752
+
753
+
754
+ def make_dataset(data_args, seed=42):
755
+ # Pre-processing dataset
756
+ import re
757
+
758
+ def map_nst(entry):
759
+ text = entry["text"].lower()
760
+ text = text.replace("(...vær stille under dette opptaket...)", "")
761
+ text = re.sub('[áàâ]', 'a', text)
762
+ text = re.sub('[ä]', 'æ', text)
763
+ text = re.sub('[éèëê]', 'e', text)
764
+ text = re.sub('[íìïî]', 'i', text)
765
+ text = re.sub('[óòöô]', 'o', text)
766
+ text = re.sub('[ö]', 'ø', text)
767
+ text = re.sub('[ç]', 'c', text)
768
+ text = re.sub('[úùüû]', 'u', text)
769
+ # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
770
+ text = re.sub('\s+', ' ', text)
771
+ return {"text": text}
772
+
773
+ def filter_nst(entry):
774
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
775
+ return False # Too short
776
+ if re.match(entry["type"], "pIW|CA"):
777
+ return False # Spelling out words
778
+ return True
779
+
780
+ def filter_npsc(entry):
781
+ # False if there are digits in the text
782
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
783
+ return False # Too short
784
+ if re.search("\d", entry["text"]):
785
+ return False
786
+ return True
787
+
788
+ def map_npsc(entry):
789
+ batch = {"text": entry["text"].lower()}
790
+ batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
791
+ batch["text"] = re.sub('[ä]', 'æ', batch["text"])
792
+ batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
793
+ batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
794
+ batch["text"] = re.sub('[óòöô]', 'o', batch["text"])
795
+ batch["text"] = re.sub('[ö]', 'ø', batch["text"])
796
+ batch["text"] = re.sub('[ç]', 'c', batch["text"])
797
+ batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
798
+ batch["text"] = re.sub('\s', ' ', batch["text"])
799
+ batch["text"] = re.sub('<ee>', 'eee', batch["text"])
800
+ batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
801
+ batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
802
+ batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
803
+ # batch["text"] = re.sub('<inaudible>', '?', batch["text"])
804
+ if "<" in batch["text"]:
805
+ raise ValueError(batch["text"])
806
+ return batch
807
+
808
+ nst = datasets.load_dataset("NbAiLab/NST", "no-close")
809
+ npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
810
+ # TODO NST_hesitate
811
+
812
+ split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC
813
+ nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed)
814
+ nst[data_args.train_split_name] = nst_train["train"]
815
+ nst[data_args.eval_split_name] = nst_train["test"]
816
+
817
+ nst = nst.filter(filter_nst).map(
818
+ map_nst,
819
+ num_proc=data_args.preprocessing_num_workers,
820
+ desc="filtering NST",
821
+ ).shuffle(seed=seed)
822
+ npsc = npsc.filter(filter_npsc).map(
823
+ map_npsc,
824
+ num_proc=data_args.preprocessing_num_workers,
825
+ desc="filtering NPSC",
826
+ ).shuffle(seed=seed)
827
+
828
+ npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]])
829
+ nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]])
830
+
831
+ combined = {}
832
+ for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name:
833
+ probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples
834
+ probs = (probs / probs.sum()).tolist()
835
+ comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
836
+ combined[split] = comb
837
+
838
+ return datasets.DatasetDict(**combined)
839
+
840
+ def main():
841
+ # 1. Parse input arguments
842
+ # See all possible arguments in src/transformers/training_args.py
843
+ # or by passing the --help flag to this script.
844
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
845
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
846
+
847
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
848
+ # If we pass only one argument to the script and it's the path to a json file,
849
+ # let's parse it to get our arguments.
850
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
851
+ else:
852
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
853
+
854
+ # 2. Setup logging
855
+ # Make one log on every process with the configuration for debugging.
856
+ logging.basicConfig(
857
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
858
+ datefmt="%m/%d/%Y %H:%M:%S",
859
+ handlers=[logging.StreamHandler(sys.stdout)],
860
+ )
861
+ # Set the verbosity to info of the Transformers logger.
862
+ # We only want one process per machine to log things on the screen.
863
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
864
+ if jax.process_index() == 0:
865
+ datasets.utils.logging.set_verbosity_warning()
866
+ transformers.utils.logging.set_verbosity_info()
867
+ else:
868
+ datasets.utils.logging.set_verbosity_error()
869
+ transformers.utils.logging.set_verbosity_error()
870
+
871
+ # Set up wandb run
872
+ if jax.process_index() == 0:
873
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
874
+
875
+ logger.info("Training/evaluation parameters %s", training_args)
876
+
877
+ # Set the default TPU matmul precision and display the number of devices
878
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
879
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
880
+
881
+ # 4. Load dataset
882
+
883
+ set_seed(training_args.seed)
884
+ raw_datasets = make_dataset(data_args, seed=training_args.seed)
885
+
886
+ # raw_datasets = DatasetDict()
887
+
888
+ # if training_args.do_train:
889
+ # raw_datasets[data_args.train_split_name] = load_dataset(
890
+ # data_args.dataset_name,
891
+ # data_args.dataset_config_name,
892
+ # split=data_args.train_split_name,
893
+ # cache_dir=data_args.dataset_cache_dir,
894
+ # use_auth_token=True if model_args.use_auth_token else None,
895
+ # )
896
+
897
+ # if training_args.do_eval:
898
+ # raw_datasets[data_args.eval_split_name] = load_dataset(
899
+ # data_args.dataset_name,
900
+ # data_args.dataset_config_name,
901
+ # split=data_args.eval_split_name,
902
+ # cache_dir=data_args.dataset_cache_dir,
903
+ # use_auth_token=True if model_args.use_auth_token else None,
904
+ # )
905
+
906
+ # if training_args.do_predict:
907
+ # test_split = data_args.test_split_name.split("+")
908
+ # for split in test_split:
909
+ # raw_datasets[split] = load_dataset(
910
+ # data_args.dataset_name,
911
+ # data_args.dataset_config_name,
912
+ # split=split,
913
+ # cache_dir=data_args.dataset_cache_dir,
914
+ # use_auth_token=True if model_args.use_auth_token else None,
915
+ # )
916
+
917
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
918
+ raise ValueError(
919
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
920
+ "training, evaluation or prediction has to be done."
921
+ )
922
+
923
+ # if not training, there is no need to run multiple epochs
924
+ if not training_args.do_train:
925
+ training_args.num_train_epochs = 1
926
+
927
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
928
+ raise ValueError(
929
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
930
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
931
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
932
+ )
933
+
934
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
935
+ raise ValueError(
936
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
937
+ "Make sure to set `--text_column_name` to the correct text column - one of "
938
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
939
+ )
940
+
941
+ # 5. Load pretrained model, tokenizer, and feature extractor
942
+ #
943
+ # Distributed training:
944
+ # The .from_pretrained methods guarantee that only one local process can concurrently
945
+ config = Wav2Vec2Config.from_pretrained(
946
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
947
+ cache_dir=model_args.cache_dir,
948
+ revision=model_args.model_revision,
949
+ use_auth_token=True if model_args.use_auth_token else None,
950
+ )
951
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
952
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
953
+ cache_dir=model_args.cache_dir,
954
+ revision=model_args.model_revision,
955
+ use_auth_token=True if model_args.use_auth_token else None,
956
+ )
957
+ tokenizer = AutoTokenizer.from_pretrained(
958
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
959
+ cache_dir=model_args.cache_dir,
960
+ revision=model_args.model_revision,
961
+ use_auth_token=True if model_args.use_auth_token else None,
962
+ )
963
+ # update config according to training args, model args, and tokenizer attributes
964
+ config.update(
965
+ {
966
+ "feat_proj_dropout": model_args.feat_proj_dropout,
967
+ "attention_dropout": model_args.attention_dropout,
968
+ "hidden_dropout": model_args.hidden_dropout,
969
+ "final_dropout": model_args.final_dropout,
970
+ "mask_time_prob": model_args.mask_time_prob,
971
+ "mask_time_length": model_args.mask_time_length,
972
+ "mask_feature_prob": model_args.mask_feature_prob,
973
+ "mask_feature_length": model_args.mask_feature_length,
974
+ "gradient_checkpointing": training_args.gradient_checkpointing,
975
+ "layerdrop": model_args.layerdrop,
976
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
977
+ "ctc_zero_infinity": model_args.ctc_zero_infinity,
978
+ "pad_token_id": tokenizer.pad_token_id,
979
+ "vocab_size": tokenizer.vocab_size, # len(tokenizer),
980
+ "activation_dropout": model_args.activation_dropout,
981
+ }
982
+ )
983
+
984
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
985
+ raise ValueError(
986
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
987
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus,"
988
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely "
989
+ "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`."
990
+ )
991
+
992
+ if training_args.precision == "full_mixed":
993
+ dtype = jnp.bfloat16
994
+ training_args.mixed_precision = True
995
+ elif training_args.precision == "half_mixed":
996
+ dtype = jnp.bfloat16
997
+ training_args.mixed_precision = False
998
+ else:
999
+ dtype = jnp.float32
1000
+ training_args.mixed_precision = False
1001
+
1002
+ try:
1003
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1004
+ model_args.model_name_or_path,
1005
+ config=config,
1006
+ dtype=dtype,
1007
+ cache_dir=model_args.cache_dir,
1008
+ revision=model_args.model_revision,
1009
+ use_auth_token=True if model_args.use_auth_token else None,
1010
+ )
1011
+ except Exception:  # fall back to loading and converting PyTorch weights when Flax weights cannot be loaded
1012
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1013
+ model_args.model_name_or_path,
1014
+ config=config,
1015
+ dtype=dtype,
1016
+ cache_dir=model_args.cache_dir,
1017
+ revision=model_args.model_revision,
1018
+ use_auth_token=True if model_args.use_auth_token else None,
1019
+ from_pt=True,
1020
+ )
1021
+
1022
+ # 6. Resample speech dataset ALWAYS
1023
+ raw_datasets = raw_datasets.cast_column(
1024
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
1025
+ )
1026
+
1027
+ # 7. Preprocessing the datasets.
1028
+ # We need to read the audio files as arrays and tokenize the targets.
1029
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
1030
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
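+ # e.g. with the standard 16 kHz wav2vec2 feature extractor, max_duration_in_seconds=30.0 corresponds
+ # to 30.0 * 16_000 = 480_000 samples and min_duration_in_seconds=0.5 to 8_000 samples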
1031
+ max_target_length = data_args.max_label_length
1032
+ min_target_length = data_args.min_label_length
1033
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
1034
+ audio_column_name = data_args.audio_column_name
1035
+ num_workers = data_args.preprocessing_num_workers
1036
+ text_column_name = data_args.text_column_name
1037
+ model_input_name = feature_extractor.model_input_names[0]
1038
+ do_lower_case = data_args.do_lower_case
1039
+ dataset_name = data_args.dataset_name
1040
+ chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
1041
+ chars_to_ignore_regex = f'[{re.escape("".join(chars_to_ignore))}]'  # escape so that "-" is not interpreted as a character range
1042
+ # gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
1043
+ # gigaspeech_disfluencies = ["<other>", "<sil>"]
1044
+ # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
1045
+ # "[vocalized-noise]", "_1"]
1046
+ # swb_punctuations = ["{", "}", "[", "]-", "]"]
1047
+ # earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>", "<unk>"]
1048
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
1049
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
1050
+
1051
+ if training_args.do_train and data_args.max_train_samples is not None:
1052
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples))
1053
+
1054
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1055
+ raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples))
1056
+
1057
+ if training_args.do_predict and data_args.max_test_samples is not None:
1058
+ raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_test_samples))
1059
+
1060
+ if training_args.do_train and data_args.remove_punctuation:
1061
+
1062
+ def remove_punctuation(batch):
1063
+ batch[text_column_name] = (
1064
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
1065
+ )
+ return batch  # datasets.map discards the edit unless the updated example is returned
1066
+
1067
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map(
1068
+ remove_punctuation,
1069
+ num_proc=data_args.preprocessing_num_workers,
1070
+ desc="removing punctuation from train split",
1071
+ )
1072
+
1073
+ # filter data where the targets are ignored in scoring
1074
+ def is_target_labels(input_str):
1075
+ return input_str.lower() not in ignore_segments
1076
+
1077
+ raw_datasets = raw_datasets.filter(
1078
+ is_target_labels,
1079
+ num_proc=num_workers,
1080
+ input_columns=[text_column_name],
1081
+ desc="filtering data where the targets are ignored in scoring",
1082
+ )
1083
+
1084
+ def prepare_dataset(batch):
1085
+ # process audio
1086
+ try:
1087
+ sample = batch[audio_column_name]
1088
+ except ValueError:
1089
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
1090
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1091
+ # process audio length
1092
+ batch[model_input_name] = inputs.input_values[0]
1093
+ batch["input_length"] = len(batch["input_values"])
1094
+
1095
+ # process targets
1096
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
1097
+
1098
+ # if dataset_name == "google/xtreme_s":
1099
+ # # Finally, we tokenize the processed text
1100
+ # batch["labels"] = tokenizer(input_str).input_ids
1101
+ # batch["labels_length"] = len(batch["labels"])
1102
+ # return batch
1103
+
1104
+ # # Common Voice 9
1105
+ # if input_str.startswith('"') and input_str.endswith('"'):
1106
+ # # we can remove trailing quotation marks as they do not affect the transcription
1107
+ # input_str = input_str[1:-1]
1108
+ # # normalize quotation marks
1109
+ # input_str = re.sub(r'["“”]', '"', input_str)
1110
+ # # normalize apostrophes
1111
+ # input_str = re.sub(r"[’']", "'", input_str)
1112
+ # # normalize hyphens
1113
+ # input_str = re.sub(r"[—–]", "-", input_str)
1114
+ # # replace double quotation marks with single
1115
+ # input_str = input_str.replace('""', '"')
1116
+ # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str):
1117
+ # # for CV9, we'll normalize the text to always finish with punctuation
1118
+ # if input_str[-1] not in [".", "?", "!"]:
1119
+ # input_str = input_str + "."
1120
+
1121
+ # # TEDLIUM-3
1122
+ # # delete the <unk> token from the text and replace spaced apostrophes with un-spaced
1123
+ # input_str = input_str.replace("<unk>", "").replace(" '", "'")
1124
+
1125
+ # # GigaSpeech
1126
+ # for disfluency in gigaspeech_disfluencies:
1127
+ # input_str = input_str.replace(disfluency, "")
1128
+ # # convert spelled out punctuation to symbolic form
1129
+ # for punctuation, replacement in gigaspeech_punctuation.items():
1130
+ # input_str = input_str.replace(punctuation, replacement)
1131
+ # if dataset_name == "speechcolab/gigaspeech" and len(input_str):
1132
+ # # for GS, we'll normalize the text to always finish with punctuation
1133
+ # if input_str[-1] not in [".", "?", "!"]:
1134
+ # input_str = input_str + "."
1135
+
1136
+ # # SWB
1137
+ # for disfluency in swb_disfluencies:
1138
+ # input_str = input_str.replace(disfluency, "")
1139
+ # # remove parenthesised text (test data only)
1140
+ # input_str = re.sub("[\(].*?[\)]", "", input_str)
1141
+ # for punctuation in swb_punctuations:
1142
+ # input_str = input_str.replace(punctuation, "")
1143
+ # # replace anomalous words with their correct transcriptions
1144
+ # split_str = input_str.split("/")
1145
+ # if len(split_str) > 1:
1146
+ # input_str = " ".join(
1147
+ # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1148
+
1149
+ # # Earnings 22
1150
+ # for disfluency in earnings_disfluencies:
1151
+ # input_str = input_str.replace(disfluency, "")
1152
+ # # replace mal-formatted ellipsis
1153
+ # input_str = input_str.replace("…", ".")
1154
+
1155
+ # JIWER compliance
1156
+ # remove multiple spaces
1157
+ input_str = re.sub(r"\s\s+", " ", input_str)
1158
+ # strip trailing spaces
1159
+ input_str = input_str.strip()
1160
+
1161
+ # Finally, we tokenize the processed text
1162
+ batch["labels"] = tokenizer(input_str).input_ids
1163
+ batch["labels_length"] = len(batch["labels"])
1164
+ return batch
1165
+
1166
+ vectorized_datasets = raw_datasets.map(
1167
+ prepare_dataset,
1168
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1169
+ num_proc=num_workers,
1170
+ desc="preprocess dataset",
1171
+ )
1172
+
1173
+ # filter data with inputs shorter than min_input_length or longer than max_input_length
1174
+ def is_audio_in_length_range(length):
1175
+ return length > min_input_length and length < max_input_length
1176
+
1177
+ vectorized_datasets = vectorized_datasets.filter(
1178
+ is_audio_in_length_range,
1179
+ num_proc=num_workers,
1180
+ input_columns=["input_length"],
1181
+ )
1182
+
1183
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1184
+ def is_labels_in_length_range(length):
1185
+ return length > min_target_length # and length < max_target_length
1186
+
1187
+ vectorized_datasets = vectorized_datasets.filter(
1188
+ is_labels_in_length_range,
1189
+ num_proc=num_workers,
1190
+ input_columns=["labels_length"],
1191
+ )
1192
+
1193
+ # for large datasets it is advised to run the preprocessing on a
1194
+ # single machine first with `args.preprocessing_only` since there will most likely
1195
+ # be a timeout when running the script in distributed mode.
1196
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1197
+ # cached dataset
1198
+ if data_args.preprocessing_only:
1199
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1200
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1201
+ return
1202
+
1203
+ # 8. Load Metrics
1204
+ wer_metric = load_metric("wer")
1205
+ cer_metric = load_metric("cer")
1206
+
1207
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1208
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1209
+
1210
+ pred_str = tokenizer.batch_decode(pred_ids)
1211
+ # we do not want to group tokens when computing the metrics
1212
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
1213
+
1214
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
1215
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
1216
+
1217
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1218
+
1219
+ # 9. save feature extractor, tokenizer and config
1220
+ feature_extractor.save_pretrained(training_args.output_dir)
1221
+ tokenizer.save_pretrained(training_args.output_dir)
1222
+ config.save_pretrained(training_args.output_dir)
1223
+
1224
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1225
+
1226
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1227
+ processor=processor,
1228
+ input_padding="longest",
1229
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1230
+ max_label_length=data_args.max_label_length,
1231
+ )
1232
+
1233
+ # Enable tensorboard only on the master node
1234
+ has_tensorboard = is_tensorboard_available()
1235
+ if has_tensorboard and jax.process_index() == 0:
1236
+ try:
1237
+ from flax.metrics.tensorboard import SummaryWriter
1238
+
1239
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1240
+ except ImportError as ie:
1241
+ has_tensorboard = False
1242
+ logger.warning(
1243
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1244
+ )
1245
+ else:
1246
+ logger.warning(
1247
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1248
+ "Please run `pip install tensorboard` to enable."
1249
+ )
1250
+
1251
+ # 10. Handle the repository creation
1252
+ if training_args.push_to_hub:
1253
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1254
+ git_lfs_extensions = f.read()
1255
+ if "*.wandb" not in git_lfs_extensions:
1256
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1257
+ if training_args.hub_model_id is None:
1258
+ repo_name = get_full_repo_name(
1259
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1260
+ )
1261
+ else:
1262
+ repo_name = training_args.hub_model_id
1263
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1264
+
1265
+ # 11. Initialize our training
1266
+ rng = jax.random.PRNGKey(training_args.seed)
1267
+ rng, dropout_rng = jax.random.split(rng)
1268
+
1269
+ # Store some constants
1270
+ max_steps = int(training_args.max_steps)
1271
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1272
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1273
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1274
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1275
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1276
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1277
+
1278
+ if training_args.do_train:
1279
+ num_train_samples = len(vectorized_datasets[data_args.train_split_name])
1280
+ steps_per_epoch = num_train_samples // batch_size_per_update
1281
+ if max_steps > 0:
1282
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
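+ # -(a // -b) is ceiling division, e.g. -(10 // -3) == 4, so a partial final epoch is still counted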
1283
+ total_train_steps = max_steps
1284
+ else:
1285
+ num_epochs = int(training_args.num_train_epochs)
1286
+ total_train_steps = steps_per_epoch * num_epochs
1287
+
1288
+ # Create learning rate schedule
1290
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1291
+ total_train_steps,
1292
+ training_args.warmup_steps,
1293
+ training_args.learning_rate,
1294
+ )
1295
+
1296
+ # We use Optax's "masking" functionality to not apply weight decay
1297
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1298
+ # mask boolean with the same structure as the parameters.
1299
+ # The mask is True for parameters that should be decayed.
1300
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1301
+ # For FlaxT5, one should correct the layer norm parameter naming
1302
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1303
+ def decay_mask_fn(params):
1304
+ flat_params = traverse_util.flatten_dict(params)
1305
+ layer_norm_params = [
1306
+ (name, "scale")
1307
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1308
+ ]
1309
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1310
+ return traverse_util.unflatten_dict(flat_mask)
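+ # e.g. a (hypothetical) path ("encoder", "layers", "0", "attention", "q_proj", "kernel") maps to True
+ # (weight decay applied), while paths ending in ("layer_norm", "scale") or "bias" map to False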
1311
+
1312
+ if training_args.adafactor:
1313
+ # Create Adafactor optimizer
1314
+ optim = optax.adafactor(
1315
+ learning_rate=linear_decay_lr_schedule_fn,
1316
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1317
+ weight_decay_rate=training_args.weight_decay,
1318
+ weight_decay_mask=decay_mask_fn,
1319
+ )
1320
+ else:
1321
+ # Create AdamW optimizer
1322
+ optim = optax.adamw(
1323
+ learning_rate=linear_decay_lr_schedule_fn,
1324
+ b1=training_args.adam_beta1,
1325
+ b2=training_args.adam_beta2,
1326
+ eps=training_args.adam_epsilon,
1327
+ weight_decay=training_args.weight_decay,
1328
+ mask=decay_mask_fn,
1329
+ )
1330
+
1331
+ # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
1332
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1333
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
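+ # optax.MultiSteps only applies an update every gradient_accumulation_steps calls and, with
+ # use_grad_mean=False, sums (rather than averages) the intermediate gradients; e.g. with a
+ # hypothetical 8 devices, per_device_train_batch_size=2 and 4 accumulation steps, one update
+ # covers 8 * 2 * 4 = 64 examples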
1334
+ else:
1335
+ num_epochs = 0
1336
+ total_train_steps = 0
1337
+ num_train_samples = 0
1338
+ optim = None
1339
+
1340
+ # Setup train state
1341
+ state = MixedPrecisionTrainState.create(
1342
+ apply_fn=model.__call__,
1343
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1344
+ params=model.params,
1345
+ tx=optim,
1346
+ to_dtype=to_dtype,
1347
+ dropout_rng=dropout_rng,
1348
+ max_grad_norm=training_args.max_grad_norm,
1349
+ )
1350
+
1351
+ # Replicate the train state on each device
1352
+ state = state.replicate()
1353
+ blank_id = model.config.pad_token_id
1354
+
1355
+ # Define gradient update step fn
1356
+ def train_step(state, batch):
1357
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1358
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1359
+
1360
+ def compute_loss(params, minibatch):
1361
+ labels = minibatch.pop("labels")
1362
+ logits = state.apply_fn(
1363
+ **minibatch,
1364
+ params=params,
1365
+ dropout_rng=dropout_rng,
1366
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1367
+ train=True,
1368
+ )[0]
1369
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"])
1370
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1371
+
1372
+ return loss
1373
+
1374
+ grad_fn = jax.value_and_grad(compute_loss)
1375
+
1376
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1377
+ loss, grad = grad_fn(to_dtype(state.params), batch)
1378
+
1379
+ # Custom gradient accumulation
1380
+ else:
1381
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1382
+ batch = jax.tree_util.tree_map(
1383
+ lambda x: x.reshape(
1384
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1385
+ ),
1386
+ batch,
1387
+ )
1388
+
1389
+ def accum_minibatch_step(accum_grad, minibatch):
1390
+ # compute loss, num labels and grad over minibatch and accumulate
1391
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
1392
+ return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss
1393
+
1394
+ # create an initial state for accumulating losses, num labels and gradients
1395
+ init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params))
1396
+ # loop accum minibatch step over the number of gradient accumulation steps
1397
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
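+ # jax.lax.scan threads the gradient accumulator through the minibatches: `grad` holds the summed
+ # gradients and `loss` stacks the per-minibatch losses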
1398
+
1399
+ # update state
1400
+ new_state = state.apply_gradients(
1401
+ grads=grad,
1402
+ dropout_rng=new_dropout_rng,
1403
+ to_dtype=to_dtype,
1404
+ )
1405
+
1406
+ # compute gradient norms over all layers and globally for detailed monitoring
1407
+ layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad)
1408
+ logs = {
1409
+ "layer_grad_norm": layer_grad_norm,
1410
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
1411
+ }
1412
+
1413
+ # compute parameter norms over all layers and globally for detailed monitoring
1414
+ layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params)
1415
+ logs["layer_param_norm"] = layer_param_norm
1416
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
1417
+
1418
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1419
+ metrics.update(logs)
1420
+
1421
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1422
+ # metrics = to_fp32(metrics)
1423
+
1424
+ return new_state, metrics
1425
+
1426
+ # Define eval fn
1427
+ def eval_step(params, batch):
1428
+ labels = batch.pop("labels")
1429
+ logits = model(**batch, params=params, train=False)[0]
1430
+
1431
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
1432
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1433
+
1434
+ pred_ids = jnp.argmax(logits, axis=-1)
1435
+
1436
+ # summarize metrics
1437
+ metrics = {"loss": loss}
1438
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1439
+ # metrics = to_fp32(metrics)
1440
+ return metrics, pred_ids
1441
+
1442
+ # Create parallel version of the train and eval step
1443
+ if training_args.do_train:
1444
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1445
+
1446
+ if training_args.do_eval:
1447
+ p_eval_step = jax.pmap(eval_step, "batch")
1448
+
1449
+ def run_evaluation(step):
1450
+ if training_args.do_eval:
1451
+ # ======================== Evaluating ==============================
1452
+ eval_metrics = []
1453
+ eval_preds = []
1454
+ eval_labels = []
1455
+
1456
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1457
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size)
1458
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1459
+
1460
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1461
+ samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx]
1462
+ batch = data_collator(samples)
1463
+ labels = batch["labels"]
1464
+
1465
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
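+ # pad_shard_unpad pads the batch so it can be sharded evenly across devices (padding up to at least
+ # min_device_batch examples per device), runs the pmapped eval step, and un-pads the outputs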
1466
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1467
+ eval_metrics.append(metrics)
1468
+
1469
+ eval_labels.extend(labels)
1470
+
1471
+ # normalize eval metrics
1472
+ eval_metrics = get_metrics(eval_metrics)
1473
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1474
+ eval_metrics = to_fp32(eval_metrics)
1475
+
1476
+ # always run compute metrics
1477
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1478
+ eval_metrics.update(error_rate_metric)
1479
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1480
+
1481
+ # Print metrics and update progress bar
1482
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1483
+ epochs.write(desc)
1484
+ epochs.desc = desc
1485
+
1486
+ # Save metrics
1487
+ write_wandb_log(eval_metrics, step, prefix="eval")
1488
+ write_wandb_pred(pred_str, label_str, step)
1489
+ # if has_tensorboard and jax.process_index() == 0:
1490
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1491
+
1492
+ def save_checkpoint(step):
1493
+ # save and push checkpoint to the hub
1494
+ if jax.process_index() == 0:
1495
+ params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
1496
+ model.save_pretrained(training_args.output_dir, params=params)
1497
+ tokenizer.save_pretrained(training_args.output_dir)
1498
+ if training_args.push_to_hub:
1499
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1500
+
1501
+ skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update)
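+ # e.g. with a hypothetical 7_000 update steps per epoch, skip_steps=275000 gives skip_epochs=39,
+ # i.e. the first 39 full epochs are skipped and stepping resumes inside the 40th epoch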
1502
+ logger.info("***** Running training *****")
1503
+ logger.info(f" Num examples = {num_train_samples}")
1504
+ logger.info(f" Num Epochs = {num_epochs}")
1505
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1506
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1507
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1508
+ logger.info(f" Total optimization steps = {total_train_steps}")
1509
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
1510
+ logger.info(f" Use scan: {config.use_scan}")
1511
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
1512
+ logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)")
1513
+
1514
+ train_time = cur_step = 0
1515
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1516
+ for epoch in epochs:
1517
+ if training_args.do_train:
1518
+ # ======================== Training ================================
1519
+ train_start = time.time()
1520
+
1521
+ if epoch < skip_epochs:
1522
+ logger.info(f"Skipping epoch {epoch + 1}")
1523
+ continue
1524
+
1525
+ # Create sampling rng
1526
+ rng, input_rng = jax.random.split(rng)
1527
+
1528
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1529
+ train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
1530
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1531
+
1532
+ if data_args.skip_steps > cur_step:
1533
+ logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...")
1534
+ # Gather the indices for creating the batch and do a training step
1535
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1536
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1537
+ if cur_step <= data_args.skip_steps:
1538
+ continue
1539
+
1540
+ samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx]
1541
+ batch = data_collator(samples)
1542
+ batch = shard(batch.data)
1543
+ try:
1544
+ state, train_metric = p_train_step(state, batch)
1545
+ except TypeError as e:
1546
+ logger.warning("Encountered following error: \n", e)
1547
+
1548
+
1549
+ if cur_step % training_args.logging_steps == 0:
1550
+ # Save metrics
1551
+ train_metric = unreplicate(train_metric)
1552
+ train_time += time.time() - train_start
1553
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1554
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name)
1555
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1556
+ # if has_tensorboard and jax.process_index() == 0:
1557
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1558
+
1559
+ epochs.write(
1560
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1561
+ )
1562
+
1563
+ if cur_step % total_train_steps == 0:
1564
+ break
1565
+
1566
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1567
+ run_evaluation(cur_step)
1568
+
1569
+ if cur_step % training_args.save_steps == 0:
1570
+ save_checkpoint(cur_step)
1571
+
1572
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1573
+ # run evaluation at the end of the epoch if eval steps are not specified
1574
+ run_evaluation(cur_step)
1575
+ save_checkpoint(cur_step)
1576
+
1577
+ if training_args.do_train:
1578
+ save_checkpoint(cur_step)
1579
+
1580
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1581
+
1582
+ if training_args.do_eval:
1583
+ run_evaluation(cur_step)
1584
+
1585
+ # TODO: collapse 'do_predict' into the run_evaluation function
1586
+ if training_args.do_predict:
1587
+ for split in [data_args.test_split_name]:
1588
+ # ======================== Evaluating ==============================
1589
+ eval_metrics = []
1590
+ eval_preds = []
1591
+ eval_labels = []
1592
+
1593
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
1594
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1595
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1596
+
1597
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
1598
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1599
+ batch = data_collator(samples)
1600
+ labels = batch["labels"]
1601
+
1602
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1603
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1604
+ eval_metrics.append(metrics)
1605
+
1606
+ eval_labels.extend(labels)
1607
+
1608
+ # normalize eval metrics
1609
+ eval_metrics = get_metrics(eval_metrics)
1610
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1611
+ eval_metrics = to_fp32(eval_metrics)
1612
+
1613
+ # always run compute metrics
1614
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1615
+ eval_metrics.update(error_rate_metric)
1616
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1617
+
1618
+ # Print metrics and update progress bar
1619
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1620
+ epochs.write(desc)
1621
+ epochs.desc = desc
1622
+
1623
+ # Save metrics
1624
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
1625
+ write_wandb_pred(pred_str, label_str, cur_step, prefix=split)
1626
+ # if has_tensorboard and jax.process_index() == 0:
1627
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
1628
+
1629
+
1630
+ if __name__ == "__main__":
1631
+ main()
wandb/run-20220810_073735-23avj35z/files/config.yaml ADDED
@@ -0,0 +1,33 @@
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.9
7
+ code_path: code/run_flax_speech_recognition_ctc.py
8
+ framework: huggingface
9
+ huggingface_version: 4.21.0
10
+ is_jupyter_run: false
11
+ is_kaggle_kernel: false
12
+ python_version: 3.8.10
13
+ start_time: 1660117055
14
+ t:
15
+ 1:
16
+ - 1
17
+ - 2
18
+ - 3
19
+ - 11
20
+ - 12
21
+ 2:
22
+ - 1
23
+ - 2
24
+ - 3
25
+ - 11
26
+ - 12
27
+ 3:
28
+ - 13
29
+ 4: 3.8.10
30
+ 5: 0.12.9
31
+ 6: 4.21.0
32
+ 8:
33
+ - 5
wandb/run-20220810_073735-23avj35z/files/diff.patch ADDED
@@ -0,0 +1,27 @@
1
+ diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log
2
+ index 23926ef..9213b33 120000
3
+ --- a/wandb/debug-internal.log
4
+ +++ b/wandb/debug-internal.log
5
+ @@ -1 +1 @@
6
+ -run-20220805_230151-2y71vcu4/logs/debug-internal.log
7
+
8
+ +run-20220810_073735-23avj35z/logs/debug-internal.log
9
+
10
+ diff --git a/wandb/debug.log b/wandb/debug.log
11
+ index 279853d..bcac724 120000
12
+ --- a/wandb/debug.log
13
+ +++ b/wandb/debug.log
14
+ @@ -1 +1 @@
15
+ -run-20220805_230151-2y71vcu4/logs/debug.log
16
+
17
+ +run-20220810_073735-23avj35z/logs/debug.log
18
+
19
+ diff --git a/wandb/latest-run b/wandb/latest-run
20
+ index f069a7a..1406fac 120000
21
+ --- a/wandb/latest-run
22
+ +++ b/wandb/latest-run
23
+ @@ -1 +1 @@
24
+ -run-20220805_230151-2y71vcu4
25
+
26
+ +run-20220810_073735-23avj35z
27
+
wandb/run-20220810_073735-23avj35z/files/output.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b15353b15528bc042b1df6aa006abb62291e8c20dc7fa0bfe25bddcdf5307ef
3
+ size 166570
wandb/run-20220810_073735-23avj35z/files/requirements.txt ADDED
@@ -0,0 +1,158 @@
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ appdirs==1.4.4
5
+ astunparse==1.6.3
6
+ async-timeout==4.0.2
7
+ attrs==21.4.0
8
+ audioread==2.1.9
9
+ backcall==0.2.0
10
+ cachetools==4.2.4
11
+ certifi==2021.10.8
12
+ cffi==1.15.1
13
+ charset-normalizer==2.0.10
14
+ chex==0.1.3
15
+ click==8.0.3
16
+ cloud-tpu-client==0.10
17
+ cloud-tpu-profiler==2.4.0
18
+ clu==0.0.6
19
+ colorama==0.4.5
20
+ commonmark==0.9.1
21
+ configparser==5.2.0
22
+ contextlib2==21.6.0
23
+ cycler==0.11.0
24
+ datasets==2.4.0
25
+ decorator==5.1.0
26
+ dill==0.3.4
27
+ dm-tree==0.1.6
28
+ docker-pycreds==0.4.0
29
+ etils==0.6.0
30
+ exceptiongroup==1.0.0rc8
31
+ filelock==3.4.2
32
+ flatbuffers==2.0
33
+ flax==0.5.3
34
+ fonttools==4.28.5
35
+ frozenlist==1.2.0
36
+ fsspec==2021.11.1
37
+ future==0.18.2
38
+ gast==0.4.0
39
+ gitdb==4.0.9
40
+ gitpython==3.1.26
41
+ google-api-core==1.31.5
42
+ google-api-python-client==1.8.0
43
+ google-auth-httplib2==0.1.0
44
+ google-auth-oauthlib==0.4.6
45
+ google-auth==2.3.3
46
+ google-pasta==0.2.0
47
+ googleapis-common-protos==1.54.0
48
+ grpcio==1.43.0
49
+ h5py==3.6.0
50
+ httplib2==0.20.2
51
+ huggingface-hub==0.2.1
52
+ hypothesis==6.53.0
53
+ idna==3.3
54
+ importlib-metadata==4.10.0
55
+ importlib-resources==5.4.0
56
+ ipython==7.31.0
57
+ jax==0.3.15
58
+ jaxlib==0.3.15
59
+ jedi==0.18.1
60
+ jiwer==2.3.0
61
+ joblib==1.1.0
62
+ keras-preprocessing==1.1.2
63
+ keras==2.7.0
64
+ kiwisolver==1.3.2
65
+ libclang==12.0.0
66
+ librosa==0.9.2
67
+ libtpu-nightly==0.1.dev20220722
68
+ llvmlite==0.39.0
69
+ markdown==3.3.6
70
+ matplotlib-inline==0.1.3
71
+ matplotlib==3.5.1
72
+ ml-collections==0.1.0
73
+ msgpack==1.0.3
74
+ multidict==5.2.0
75
+ multiprocess==0.70.12.2
76
+ numba==0.56.0
77
+ numpy==1.22.0
78
+ oauth2client==4.1.3
79
+ oauthlib==3.1.1
80
+ opt-einsum==3.3.0
81
+ optax==0.1.3
82
+ packaging==21.3
83
+ pandas==1.3.5
84
+ parso==0.8.3
85
+ pathtools==0.1.2
86
+ pexpect==4.8.0
87
+ pickleshare==0.7.5
88
+ pillow==9.0.0
89
+ pip==22.2.1
90
+ pkg-resources==0.0.0
91
+ pooch==1.6.0
92
+ promise==2.3
93
+ prompt-toolkit==3.0.24
94
+ protobuf==3.19.1
95
+ psutil==5.9.0
96
+ ptyprocess==0.7.0
97
+ pyarrow==6.0.1
98
+ pyasn1-modules==0.2.8
99
+ pyasn1==0.4.8
100
+ pycparser==2.21
101
+ pyctcdecode==0.4.0
102
+ pygments==2.11.1
103
+ pygtrie==2.5.0
104
+ pyparsing==3.0.6
105
+ python-dateutil==2.8.2
106
+ python-levenshtein==0.12.2
107
+ pytz==2021.3
108
+ pyyaml==6.0
109
+ regex==2021.11.10
110
+ requests-oauthlib==1.3.0
111
+ requests==2.27.0
112
+ resampy==0.3.1
113
+ responses==0.18.0
114
+ rich==11.2.0
115
+ rsa==4.8
116
+ sacremoses==0.0.46
117
+ scikit-learn==1.1.1
118
+ scipy==1.7.3
119
+ sentry-sdk==1.5.2
120
+ setuptools==44.0.0
121
+ shortuuid==1.0.8
122
+ six==1.16.0
123
+ smmap==5.0.0
124
+ sortedcontainers==2.4.0
125
+ soundfile==0.10.3.post1
126
+ sox==1.4.1
127
+ subprocess32==3.5.4
128
+ tensorboard-data-server==0.6.1
129
+ tensorboard-plugin-wit==1.8.0
130
+ tensorboard==2.7.0
131
+ tensorflow-cpu==2.7.0
132
+ tensorflow-datasets==4.4.0
133
+ tensorflow-estimator==2.7.0
134
+ tensorflow-io-gcs-filesystem==0.23.1
135
+ tensorflow-metadata==1.5.0
136
+ tensorflow==2.7.0
137
+ tensorstore==0.1.21
138
+ termcolor==1.1.0
139
+ threadpoolctl==3.1.0
140
+ tokenizers==0.11.2
141
+ toolz==0.11.2
142
+ torch==1.12.0
143
+ torchaudio==0.12.0+cpu
144
+ tqdm==4.62.3
145
+ traitlets==5.1.1
146
+ transformers==4.21.0
147
+ typing-extensions==4.3.0
148
+ uritemplate==3.0.1
149
+ urllib3==1.26.7
150
+ wandb==0.12.9
151
+ wcwidth==0.2.5
152
+ werkzeug==2.0.2
153
+ wheel==0.37.1
154
+ wrapt==1.13.3
155
+ xxhash==2.0.2
156
+ yarl==1.7.2
157
+ yaspin==2.1.0
158
+ zipp==3.7.0
wandb/run-20220810_073735-23avj35z/files/wandb-metadata.json ADDED
@@ -0,0 +1,70 @@
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2022-08-10T07:37:39.012020",
5
+ "startedAt": "2022-08-10T07:37:35.560272",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--model_name_or_path=./",
11
+ "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu",
12
+ "--tokenizer_name=./",
13
+ "--output_dir=./",
14
+ "--overwrite_output_dir",
15
+ "--num_train_epochs=40",
16
+ "--per_device_train_batch_size=2",
17
+ "--per_device_eval_batch_size=2",
18
+ "--gradient_accumulation_steps=1",
19
+ "--precision=full_mixed",
20
+ "--matmul_precision=bfloat16",
21
+ "--multisteps",
22
+ "--learning_rate=6.394633237505332e-05",
23
+ "--skip_steps=275000",
24
+ "--warmup_steps=2000",
25
+ "--length_column_name=input_length",
26
+ "--evaluation_strategy=steps",
27
+ "--text_column_name=text",
28
+ "--save_steps=5000",
29
+ "--eval_steps=5000",
30
+ "--logging_steps=100",
31
+ "--layerdrop=0.041",
32
+ "--attention_dropout=0.094",
33
+ "--activation_dropout=0.055",
34
+ "--hidden_dropout=0.047",
35
+ "--save_total_limit=5",
36
+ "--freeze_feature_encoder",
37
+ "--feat_proj_dropout=0.04",
38
+ "--mask_time_prob=0.082",
39
+ "--mask_time_length=10",
40
+ "--mask_feature_prob=0.25",
41
+ "--mask_feature_length=64",
42
+ "--gradient_checkpointing",
43
+ "--min_duration_in_seconds=0.5",
44
+ "--max_duration_in_seconds=30.0",
45
+ "--use_auth_token",
46
+ "--seed=42",
47
+ "--group_by_length",
48
+ "--do_train",
49
+ "--do_eval",
50
+ "--push_to_hub",
51
+ "--preprocessing_num_workers=32",
52
+ "--ctc_zero_infinity",
53
+ "--do_lower_case",
54
+ "--wandb_project=wav2vec2",
55
+ "--wandb_name=wav2vec2-1b-npsc-nst-tpu (cont.)",
56
+ "--remove_punctuation"
57
+ ],
58
+ "state": "running",
59
+ "program": "run_flax_speech_recognition_ctc.py",
60
+ "codePath": "run_flax_speech_recognition_ctc.py",
61
+ "git": {
62
+ "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu",
63
+ "commit": "f624ac4bfedfbf56891676d7c5f2e37b4c8e0745"
64
+ },
65
+ "email": "versae@gmail.com",
66
+ "root": "/data/wav2vec2-1b-npsc-nst-tpu",
67
+ "host": "t1v-n-eedfb410-w-0",
68
+ "username": "javierr",
69
+ "executable": "/data/flax/bin/python"
70
+ }
wandb/run-20220810_073735-23avj35z/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
1
+ {"train/grad_norm": 6.5625, "layer_grad_norm/": {"lm_head": {"bias": 0.031982421875, "kernel": 4.625}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 0.0556640625, "scale": 0.06103515625}, "layers": {"0": {"attention": {"k_proj": {"bias": 5.9604644775390625e-05, "kernel": 0.023681640625}, "out_proj": {"bias": 0.04150390625, "kernel": 0.2431640625}, "q_proj": {"bias": 0.002899169921875, "kernel": 0.031005859375}, "v_proj": {"bias": 0.037109375, "kernel": 0.265625}}, "feed_forward": {"intermediate_dense": {"bias": 0.04443359375, "kernel": 0.515625}, "output_dense": {"bias": 0.0159912109375, "kernel": 0.439453125}}, "final_layer_norm": {"bias": 0.146484375, "scale": 0.322265625}, "layer_norm": {"bias": 0.0703125, "scale": 0.07080078125}}, "1": {"attention": {"k_proj": {"bias": 3.4332275390625e-05, "kernel": 0.03955078125}, "out_proj": {"bias": 0.0177001953125, "kernel": 0.134765625}, "q_proj": {"bias": 0.0035247802734375, "kernel": 0.0439453125}, "v_proj": {"bias": 0.02880859375, "kernel": 0.111328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.022216796875, "kernel": 0.3359375}, "output_dense": {"bias": 0.0157470703125, "kernel": 0.259765625}}, "final_layer_norm": {"bias": 0.046142578125, "scale": 0.05712890625}, "layer_norm": {"bias": 0.05712890625, "scale": 0.039794921875}}, "10": {"attention": {"k_proj": {"bias": 3.600120544433594e-05, "kernel": 0.123046875}, "out_proj": {"bias": 0.01416015625, "kernel": 0.2001953125}, "q_proj": {"bias": 0.0078125, "kernel": 0.12255859375}, "v_proj": {"bias": 0.022705078125, "kernel": 0.2001953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.021484375, "kernel": 0.328125}, "output_dense": {"bias": 0.013671875, "kernel": 0.2734375}}, "final_layer_norm": {"bias": 0.0390625, "scale": 0.0322265625}, "layer_norm": {"bias": 0.04931640625, "scale": 0.0341796875}}, "11": {"attention": {"k_proj": {"bias": 8.344650268554688e-05, "kernel": 0.158203125}, "out_proj": {"bias": 0.0142822265625, "kernel": 0.28125}, "q_proj": {"bias": 0.0087890625, "kernel": 0.130859375}, "v_proj": {"bias": 0.024658203125, "kernel": 0.28515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.01953125, "kernel": 0.310546875}, "output_dense": {"bias": 0.013916015625, "kernel": 0.244140625}}, "final_layer_norm": {"bias": 0.03271484375, "scale": 0.0308837890625}, "layer_norm": {"bias": 0.05029296875, "scale": 0.0439453125}}, "12": {"attention": {"k_proj": {"bias": 4.982948303222656e-05, "kernel": 0.123046875}, "out_proj": {"bias": 0.0147705078125, "kernel": 0.244140625}, "q_proj": {"bias": 0.0081787109375, "kernel": 0.1162109375}, "v_proj": {"bias": 0.023681640625, "kernel": 0.2294921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.021484375, "kernel": 0.32421875}, "output_dense": {"bias": 0.014404296875, "kernel": 0.255859375}}, "final_layer_norm": {"bias": 0.0390625, "scale": 0.04248046875}, "layer_norm": {"bias": 0.046630859375, "scale": 0.0546875}}, "13": {"attention": {"k_proj": {"bias": 0.00012493133544921875, "kernel": 0.15625}, "out_proj": {"bias": 0.01519775390625, "kernel": 0.330078125}, "q_proj": {"bias": 0.0111083984375, "kernel": 0.158203125}, "v_proj": {"bias": 0.026611328125, "kernel": 0.3046875}}, "feed_forward": {"intermediate_dense": {"bias": 0.022216796875, "kernel": 0.34375}, "output_dense": {"bias": 0.01513671875, "kernel": 0.3125}}, "final_layer_norm": {"bias": 0.040283203125, "scale": 0.032958984375}, "layer_norm": {"bias": 0.051513671875, "scale": 0.091796875}}, "14": {"attention": {"k_proj": {"bias": 9.441375732421875e-05, "kernel": 
0.1005859375}, "out_proj": {"bias": 0.015625, "kernel": 0.2412109375}, "q_proj": {"bias": 0.006256103515625, "kernel": 0.099609375}, "v_proj": {"bias": 0.0235595703125, "kernel": 0.2275390625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0257568359375, "kernel": 0.39453125}, "output_dense": {"bias": 0.015380859375, "kernel": 0.33984375}}, "final_layer_norm": {"bias": 0.05126953125, "scale": 0.05517578125}, "layer_norm": {"bias": 0.041748046875, "scale": 0.03076171875}}, "15": {"attention": {"k_proj": {"bias": 0.0003070831298828125, "kernel": 0.1806640625}, "out_proj": {"bias": 0.015625, "kernel": 0.5078125}, "q_proj": {"bias": 0.0106201171875, "kernel": 0.173828125}, "v_proj": {"bias": 0.026611328125, "kernel": 0.361328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.024169921875, "kernel": 0.376953125}, "output_dense": {"bias": 0.01556396484375, "kernel": 0.349609375}}, "final_layer_norm": {"bias": 0.041015625, "scale": 0.033447265625}, "layer_norm": {"bias": 0.048095703125, "scale": 0.072265625}}, "16": {"attention": {"k_proj": {"bias": 6.389617919921875e-05, "kernel": 0.1025390625}, "out_proj": {"bias": 0.016357421875, "kernel": 0.267578125}, "q_proj": {"bias": 0.0057373046875, "kernel": 0.1005859375}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.220703125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0223388671875, "kernel": 0.359375}, "output_dense": {"bias": 0.0159912109375, "kernel": 0.341796875}}, "final_layer_norm": {"bias": 0.0380859375, "scale": 0.033935546875}, "layer_norm": {"bias": 0.043212890625, "scale": 0.034912109375}}, "17": {"attention": {"k_proj": {"bias": 4.57763671875e-05, "kernel": 0.0927734375}, "out_proj": {"bias": 0.0172119140625, "kernel": 0.23046875}, "q_proj": {"bias": 0.005889892578125, "kernel": 0.087890625}, "v_proj": {"bias": 0.0244140625, "kernel": 0.2177734375}}, "feed_forward": {"intermediate_dense": {"bias": 0.024169921875, "kernel": 0.390625}, "output_dense": {"bias": 0.01708984375, "kernel": 0.353515625}}, "final_layer_norm": {"bias": 0.041259765625, "scale": 0.036376953125}, "layer_norm": {"bias": 0.0439453125, "scale": 0.0341796875}}, "18": {"attention": {"k_proj": {"bias": 0.000247955322265625, "kernel": 0.126953125}, "out_proj": {"bias": 0.017578125, "kernel": 0.369140625}, "q_proj": {"bias": 0.0076904296875, "kernel": 0.1337890625}, "v_proj": {"bias": 0.027587890625, "kernel": 0.298828125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02587890625, "kernel": 0.44921875}, "output_dense": {"bias": 0.0174560546875, "kernel": 0.41015625}}, "final_layer_norm": {"bias": 0.04443359375, "scale": 0.03857421875}, "layer_norm": {"bias": 0.048583984375, "scale": 0.039794921875}}, "19": {"attention": {"k_proj": {"bias": 8.678436279296875e-05, "kernel": 0.140625}, "out_proj": {"bias": 0.017822265625, "kernel": 0.28125}, "q_proj": {"bias": 0.009033203125, "kernel": 0.140625}, "v_proj": {"bias": 0.0286865234375, "kernel": 0.283203125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02587890625, "kernel": 0.474609375}, "output_dense": {"bias": 0.0174560546875, "kernel": 0.421875}}, "final_layer_norm": {"bias": 0.041748046875, "scale": 0.0380859375}, "layer_norm": {"bias": 0.052734375, "scale": 0.04052734375}}, "2": {"attention": {"k_proj": {"bias": 4.982948303222656e-05, "kernel": 0.07421875}, "out_proj": {"bias": 0.0177001953125, "kernel": 0.2060546875}, "q_proj": {"bias": 0.006195068359375, "kernel": 0.06982421875}, "v_proj": {"bias": 0.03173828125, "kernel": 0.181640625}}, "feed_forward": {"intermediate_dense": {"bias": 
0.0234375, "kernel": 0.390625}, "output_dense": {"bias": 0.01556396484375, "kernel": 0.275390625}}, "final_layer_norm": {"bias": 0.047119140625, "scale": 0.03173828125}, "layer_norm": {"bias": 0.0556640625, "scale": 0.07275390625}}, "20": {"attention": {"k_proj": {"bias": 2.110004425048828e-05, "kernel": 0.095703125}, "out_proj": {"bias": 0.0185546875, "kernel": 0.142578125}, "q_proj": {"bias": 0.005157470703125, "kernel": 0.0947265625}, "v_proj": {"bias": 0.0263671875, "kernel": 0.140625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0250244140625, "kernel": 0.4765625}, "output_dense": {"bias": 0.018310546875, "kernel": 0.390625}}, "final_layer_norm": {"bias": 0.041015625, "scale": 0.04443359375}, "layer_norm": {"bias": 0.046142578125, "scale": 0.038330078125}}, "21": {"attention": {"k_proj": {"bias": 4.00543212890625e-05, "kernel": 0.1259765625}, "out_proj": {"bias": 0.0189208984375, "kernel": 0.2216796875}, "q_proj": {"bias": 0.006927490234375, "kernel": 0.12890625}, "v_proj": {"bias": 0.02734375, "kernel": 0.203125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0267333984375, "kernel": 0.51953125}, "output_dense": {"bias": 0.0185546875, "kernel": 0.41796875}}, "final_layer_norm": {"bias": 0.04541015625, "scale": 0.04736328125}, "layer_norm": {"bias": 0.044189453125, "scale": 0.054443359375}}, "22": {"attention": {"k_proj": {"bias": 3.3855438232421875e-05, "kernel": 0.1181640625}, "out_proj": {"bias": 0.019775390625, "kernel": 0.240234375}, "q_proj": {"bias": 0.006011962890625, "kernel": 0.11279296875}, "v_proj": {"bias": 0.028076171875, "kernel": 0.21875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0269775390625, "kernel": 0.515625}, "output_dense": {"bias": 0.0194091796875, "kernel": 0.400390625}}, "final_layer_norm": {"bias": 0.046142578125, "scale": 0.047119140625}, "layer_norm": {"bias": 0.049560546875, "scale": 0.0458984375}}, "23": {"attention": {"k_proj": {"bias": 0.0001087188720703125, "kernel": 0.16015625}, "out_proj": {"bias": 0.0198974609375, "kernel": 0.443359375}, "q_proj": {"bias": 0.008544921875, "kernel": 0.1630859375}, "v_proj": {"bias": 0.03173828125, "kernel": 0.35546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0263671875, "kernel": 0.53125}, "output_dense": {"bias": 0.01953125, "kernel": 0.400390625}}, "final_layer_norm": {"bias": 0.044677734375, "scale": 0.04638671875}, "layer_norm": {"bias": 0.05615234375, "scale": 0.056396484375}}, "24": {"attention": {"k_proj": {"bias": 6.246566772460938e-05, "kernel": 0.169921875}, "out_proj": {"bias": 0.0191650390625, "kernel": 0.36328125}, "q_proj": {"bias": 0.00933837890625, "kernel": 0.18359375}, "v_proj": {"bias": 0.03271484375, "kernel": 0.328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02685546875, "kernel": 0.5390625}, "output_dense": {"bias": 0.01904296875, "kernel": 0.37890625}}, "final_layer_norm": {"bias": 0.04736328125, "scale": 0.04345703125}, "layer_norm": {"bias": 0.0625, "scale": 0.041015625}}, "25": {"attention": {"k_proj": {"bias": 6.079673767089844e-05, "kernel": 0.169921875}, "out_proj": {"bias": 0.0196533203125, "kernel": 0.3125}, "q_proj": {"bias": 0.00860595703125, "kernel": 0.16015625}, "v_proj": {"bias": 0.03271484375, "kernel": 0.32421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.02734375, "kernel": 0.55859375}, "output_dense": {"bias": 0.01953125, "kernel": 0.375}}, "final_layer_norm": {"bias": 0.050537109375, "scale": 0.0478515625}, "layer_norm": {"bias": 0.06005859375, "scale": 0.06298828125}}, "26": {"attention": {"k_proj": {"bias": 
5.9604644775390625e-05, "kernel": 0.169921875}, "out_proj": {"bias": 0.01953125, "kernel": 0.29296875}, "q_proj": {"bias": 0.01025390625, "kernel": 0.177734375}, "v_proj": {"bias": 0.0341796875, "kernel": 0.296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.026611328125, "kernel": 0.51171875}, "output_dense": {"bias": 0.01904296875, "kernel": 0.353515625}}, "final_layer_norm": {"bias": 0.0478515625, "scale": 0.04443359375}, "layer_norm": {"bias": 0.060791015625, "scale": 0.06396484375}}, "27": {"attention": {"k_proj": {"bias": 0.00011396408081054688, "kernel": 0.2021484375}, "out_proj": {"bias": 0.01806640625, "kernel": 0.44921875}, "q_proj": {"bias": 0.01068115234375, "kernel": 0.2138671875}, "v_proj": {"bias": 0.03466796875, "kernel": 0.435546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.02734375, "kernel": 0.515625}, "output_dense": {"bias": 0.0181884765625, "kernel": 0.36328125}}, "final_layer_norm": {"bias": 0.05078125, "scale": 0.045654296875}, "layer_norm": {"bias": 0.06640625, "scale": 0.04931640625}}, "28": {"attention": {"k_proj": {"bias": 0.0001049041748046875, "kernel": 0.20703125}, "out_proj": {"bias": 0.0164794921875, "kernel": 0.392578125}, "q_proj": {"bias": 0.01165771484375, "kernel": 0.208984375}, "v_proj": {"bias": 0.031494140625, "kernel": 0.404296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0234375, "kernel": 0.45703125}, "output_dense": {"bias": 0.016357421875, "kernel": 0.326171875}}, "final_layer_norm": {"bias": 0.04248046875, "scale": 0.044921875}, "layer_norm": {"bias": 0.0673828125, "scale": 0.08447265625}}, "29": {"attention": {"k_proj": {"bias": 9.918212890625e-05, "kernel": 0.267578125}, "out_proj": {"bias": 0.0157470703125, "kernel": 0.28515625}, "q_proj": {"bias": 0.01495361328125, "kernel": 0.265625}, "v_proj": {"bias": 0.02978515625, "kernel": 0.31640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0234375, "kernel": 0.494140625}, "output_dense": {"bias": 0.01531982421875, "kernel": 0.296875}}, "final_layer_norm": {"bias": 0.03955078125, "scale": 0.03515625}, "layer_norm": {"bias": 0.0654296875, "scale": 0.061279296875}}, "3": {"attention": {"k_proj": {"bias": 0.00012111663818359375, "kernel": 0.0986328125}, "out_proj": {"bias": 0.016845703125, "kernel": 0.314453125}, "q_proj": {"bias": 0.00726318359375, "kernel": 0.0888671875}, "v_proj": {"bias": 0.0283203125, "kernel": 0.2470703125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0242919921875, "kernel": 0.3828125}, "output_dense": {"bias": 0.0150146484375, "kernel": 0.271484375}}, "final_layer_norm": {"bias": 0.0458984375, "scale": 0.03125}, "layer_norm": {"bias": 0.0498046875, "scale": 0.0380859375}}, "30": {"attention": {"k_proj": {"bias": 0.0001220703125, "kernel": 0.13671875}, "out_proj": {"bias": 0.0152587890625, "kernel": 0.328125}, "q_proj": {"bias": 0.006378173828125, "kernel": 0.138671875}, "v_proj": {"bias": 0.029296875, "kernel": 0.3671875}}, "feed_forward": {"intermediate_dense": {"bias": 0.023681640625, "kernel": 0.51953125}, "output_dense": {"bias": 0.01446533203125, "kernel": 0.26171875}}, "final_layer_norm": {"bias": 0.04150390625, "scale": 0.03564453125}, "layer_norm": {"bias": 0.04931640625, "scale": 0.037109375}}, "31": {"attention": {"k_proj": {"bias": 0.00010347366333007812, "kernel": 0.14453125}, "out_proj": {"bias": 0.0140380859375, "kernel": 0.29296875}, "q_proj": {"bias": 0.006378173828125, "kernel": 0.134765625}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.314453125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02392578125, 
"kernel": 0.51953125}, "output_dense": {"bias": 0.01385498046875, "kernel": 0.2578125}}, "final_layer_norm": {"bias": 0.0390625, "scale": 0.03662109375}, "layer_norm": {"bias": 0.039306640625, "scale": 0.0291748046875}}, "32": {"attention": {"k_proj": {"bias": 8.296966552734375e-05, "kernel": 0.15625}, "out_proj": {"bias": 0.01263427734375, "kernel": 0.28125}, "q_proj": {"bias": 0.0079345703125, "kernel": 0.1533203125}, "v_proj": {"bias": 0.0264892578125, "kernel": 0.4921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0216064453125, "kernel": 0.431640625}, "output_dense": {"bias": 0.01129150390625, "kernel": 0.212890625}}, "final_layer_norm": {"bias": 0.04150390625, "scale": 0.03271484375}, "layer_norm": {"bias": 0.046630859375, "scale": 0.05419921875}}, "33": {"attention": {"k_proj": {"bias": 5.9604644775390625e-05, "kernel": 0.166015625}, "out_proj": {"bias": 0.01092529296875, "kernel": 0.2275390625}, "q_proj": {"bias": 0.008544921875, "kernel": 0.166015625}, "v_proj": {"bias": 0.023193359375, "kernel": 0.34765625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0196533203125, "kernel": 0.390625}, "output_dense": {"bias": 0.00897216796875, "kernel": 0.1875}}, "final_layer_norm": {"bias": 0.04345703125, "scale": 0.0361328125}, "layer_norm": {"bias": 0.039794921875, "scale": 0.0498046875}}, "34": {"attention": {"k_proj": {"bias": 0.0002346038818359375, "kernel": 0.158203125}, "out_proj": {"bias": 0.0081787109375, "kernel": 0.181640625}, "q_proj": {"bias": 0.006927490234375, "kernel": 0.14453125}, "v_proj": {"bias": 0.0177001953125, "kernel": 0.25390625}}, "feed_forward": {"intermediate_dense": {"bias": 0.01434326171875, "kernel": 0.291015625}, "output_dense": {"bias": 0.0072021484375, "kernel": 0.1748046875}}, "final_layer_norm": {"bias": 0.028076171875, "scale": 0.025146484375}, "layer_norm": {"bias": 0.03369140625, "scale": 0.026611328125}}, "35": {"attention": {"k_proj": {"bias": 0.0001506805419921875, "kernel": 0.10791015625}, "out_proj": {"bias": 0.00640869140625, "kernel": 0.2109375}, "q_proj": {"bias": 0.004852294921875, "kernel": 0.10791015625}, "v_proj": {"bias": 0.01177978515625, "kernel": 0.21484375}}, "feed_forward": {"intermediate_dense": {"bias": 0.010498046875, "kernel": 0.2119140625}, "output_dense": {"bias": 0.005889892578125, "kernel": 0.15234375}}, "final_layer_norm": {"bias": 0.0206298828125, "scale": 0.0220947265625}, "layer_norm": {"bias": 0.024169921875, "scale": 0.02880859375}}, "36": {"attention": {"k_proj": {"bias": 4.410743713378906e-05, "kernel": 0.1005859375}, "out_proj": {"bias": 0.005645751953125, "kernel": 0.1552734375}, "q_proj": {"bias": 0.00445556640625, "kernel": 0.095703125}, "v_proj": {"bias": 0.00946044921875, "kernel": 0.14453125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0089111328125, "kernel": 0.177734375}, "output_dense": {"bias": 0.0050048828125, "kernel": 0.111328125}}, "final_layer_norm": {"bias": 0.017578125, "scale": 0.01513671875}, "layer_norm": {"bias": 0.0191650390625, "scale": 0.01806640625}}, "37": {"attention": {"k_proj": {"bias": 9.441375732421875e-05, "kernel": 0.0849609375}, "out_proj": {"bias": 0.004913330078125, "kernel": 0.11474609375}, "q_proj": {"bias": 0.00390625, "kernel": 0.0830078125}, "v_proj": {"bias": 0.00897216796875, "kernel": 0.1318359375}}, "feed_forward": {"intermediate_dense": {"bias": 0.00823974609375, "kernel": 0.16796875}, "output_dense": {"bias": 0.004241943359375, "kernel": 0.09716796875}}, "final_layer_norm": {"bias": 0.015869140625, "scale": 0.01434326171875}, "layer_norm": {"bias": 
0.019287109375, "scale": 0.015869140625}}, "38": {"attention": {"k_proj": {"bias": 5.650520324707031e-05, "kernel": 0.09130859375}, "out_proj": {"bias": 0.0040283203125, "kernel": 0.11865234375}, "q_proj": {"bias": 0.00396728515625, "kernel": 0.08642578125}, "v_proj": {"bias": 0.007354736328125, "kernel": 0.1279296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0072021484375, "kernel": 0.150390625}, "output_dense": {"bias": 0.0034637451171875, "kernel": 0.09423828125}}, "final_layer_norm": {"bias": 0.0152587890625, "scale": 0.0146484375}, "layer_norm": {"bias": 0.0162353515625, "scale": 0.0135498046875}}, "39": {"attention": {"k_proj": {"bias": 5.316734313964844e-05, "kernel": 0.09619140625}, "out_proj": {"bias": 0.0030975341796875, "kernel": 0.09619140625}, "q_proj": {"bias": 0.00408935546875, "kernel": 0.0908203125}, "v_proj": {"bias": 0.006011962890625, "kernel": 0.10986328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.005401611328125, "kernel": 0.12109375}, "output_dense": {"bias": 0.0025634765625, "kernel": 0.08642578125}}, "final_layer_norm": {"bias": 0.01202392578125, "scale": 0.01226806640625}, "layer_norm": {"bias": 0.0150146484375, "scale": 0.01556396484375}}, "4": {"attention": {"k_proj": {"bias": 0.000148773193359375, "kernel": 0.10498046875}, "out_proj": {"bias": 0.015869140625, "kernel": 0.361328125}, "q_proj": {"bias": 0.0072021484375, "kernel": 0.1005859375}, "v_proj": {"bias": 0.026123046875, "kernel": 0.3046875}}, "feed_forward": {"intermediate_dense": {"bias": 0.023193359375, "kernel": 0.36328125}, "output_dense": {"bias": 0.014404296875, "kernel": 0.29296875}}, "final_layer_norm": {"bias": 0.042724609375, "scale": 0.034423828125}, "layer_norm": {"bias": 0.0478515625, "scale": 0.060546875}}, "40": {"attention": {"k_proj": {"bias": 5.269050598144531e-05, "kernel": 0.046875}, "out_proj": {"bias": 0.0025787353515625, "kernel": 0.080078125}, "q_proj": {"bias": 0.0020294189453125, "kernel": 0.0458984375}, "v_proj": {"bias": 0.004150390625, "kernel": 0.07080078125}}, "feed_forward": {"intermediate_dense": {"bias": 0.004302978515625, "kernel": 0.09326171875}, "output_dense": {"bias": 0.0023040771484375, "kernel": 0.060791015625}}, "final_layer_norm": {"bias": 0.0087890625, "scale": 0.011474609375}, "layer_norm": {"bias": 0.00823974609375, "scale": 0.007781982421875}}, "41": {"attention": {"k_proj": {"bias": 4.0531158447265625e-05, "kernel": 0.0673828125}, "out_proj": {"bias": 0.002044677734375, "kernel": 0.087890625}, "q_proj": {"bias": 0.0025787353515625, "kernel": 0.0634765625}, "v_proj": {"bias": 0.00439453125, "kernel": 0.1044921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0035858154296875, "kernel": 0.091796875}, "output_dense": {"bias": 0.00168609619140625, "kernel": 0.064453125}}, "final_layer_norm": {"bias": 0.00921630859375, "scale": 0.0106201171875}, "layer_norm": {"bias": 0.010986328125, "scale": 0.0128173828125}}, "42": {"attention": {"k_proj": {"bias": 1.1801719665527344e-05, "kernel": 0.02099609375}, "out_proj": {"bias": 0.001678466796875, "kernel": 0.048828125}, "q_proj": {"bias": 0.0009307861328125, "kernel": 0.02197265625}, "v_proj": {"bias": 0.002349853515625, "kernel": 0.0478515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.002685546875, "kernel": 0.07421875}, "output_dense": {"bias": 0.0014801025390625, "kernel": 0.053955078125}}, "final_layer_norm": {"bias": 0.0059814453125, "scale": 0.00885009765625}, "layer_norm": {"bias": 0.0045166015625, "scale": 0.00439453125}}, "43": {"attention": {"k_proj": {"bias": 
8.046627044677734e-06, "kernel": 0.01806640625}, "out_proj": {"bias": 0.0015106201171875, "kernel": 0.035888671875}, "q_proj": {"bias": 0.00095367431640625, "kernel": 0.0208740234375}, "v_proj": {"bias": 0.00171661376953125, "kernel": 0.031005859375}}, "feed_forward": {"intermediate_dense": {"bias": 0.002777099609375, "kernel": 0.083984375}, "output_dense": {"bias": 0.00125885009765625, "kernel": 0.052734375}}, "final_layer_norm": {"bias": 0.00689697265625, "scale": 0.007110595703125}, "layer_norm": {"bias": 0.0036468505859375, "scale": 0.005706787109375}}, "44": {"attention": {"k_proj": {"bias": 1.3113021850585938e-05, "kernel": 0.023681640625}, "out_proj": {"bias": 0.001312255859375, "kernel": 0.033447265625}, "q_proj": {"bias": 0.000946044921875, "kernel": 0.021484375}, "v_proj": {"bias": 0.0017547607421875, "kernel": 0.03515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0023956298828125, "kernel": 0.0771484375}, "output_dense": {"bias": 0.001129150390625, "kernel": 0.052001953125}}, "final_layer_norm": {"bias": 0.005706787109375, "scale": 0.005859375}, "layer_norm": {"bias": 0.0042724609375, "scale": 0.004730224609375}}, "45": {"attention": {"k_proj": {"bias": 1.4424324035644531e-05, "kernel": 0.01239013671875}, "out_proj": {"bias": 0.00110626220703125, "kernel": 0.0267333984375}, "q_proj": {"bias": 0.00136566162109375, "kernel": 0.029541015625}, "v_proj": {"bias": 0.00142669677734375, "kernel": 0.027587890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0019989013671875, "kernel": 0.06396484375}, "output_dense": {"bias": 0.000881195068359375, "kernel": 0.05615234375}}, "final_layer_norm": {"bias": 0.006072998046875, "scale": 0.006591796875}, "layer_norm": {"bias": 0.004791259765625, "scale": 0.00494384765625}}, "46": {"attention": {"k_proj": {"bias": 5.745887756347656e-05, "kernel": 0.006439208984375}, "out_proj": {"bias": 0.000888824462890625, "kernel": 0.0289306640625}, "q_proj": {"bias": 0.000591278076171875, "kernel": 0.011962890625}, "v_proj": {"bias": 0.0010986328125, "kernel": 0.0233154296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.00141143798828125, "kernel": 0.03955078125}, "output_dense": {"bias": 0.000873565673828125, "kernel": 0.0478515625}}, "final_layer_norm": {"bias": 0.00433349609375, "scale": 0.00433349609375}, "layer_norm": {"bias": 0.0036468505859375, "scale": 0.003814697265625}}, "47": {"attention": {"k_proj": {"bias": 0.00011301040649414062, "kernel": 0.003997802734375}, "out_proj": {"bias": 0.000896453857421875, "kernel": 0.06640625}, "q_proj": {"bias": 0.00014591217041015625, "kernel": 0.00286865234375}, "v_proj": {"bias": 0.00118255615234375, "kernel": 0.0230712890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0010528564453125, "kernel": 0.0252685546875}, "output_dense": {"bias": 0.000881195068359375, "kernel": 0.1787109375}}, "final_layer_norm": {"bias": 0.005950927734375, "scale": 0.00677490234375}, "layer_norm": {"bias": 0.005859375, "scale": 0.005767822265625}}, "5": {"attention": {"k_proj": {"bias": 6.4849853515625e-05, "kernel": 0.123046875}, "out_proj": {"bias": 0.0159912109375, "kernel": 0.203125}, "q_proj": {"bias": 0.007598876953125, "kernel": 0.12158203125}, "v_proj": {"bias": 0.02685546875, "kernel": 0.1953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.023193359375, "kernel": 0.33984375}, "output_dense": {"bias": 0.0147705078125, "kernel": 0.275390625}}, "final_layer_norm": {"bias": 0.041015625, "scale": 0.037841796875}, "layer_norm": {"bias": 0.052734375, "scale": 0.04736328125}}, "6": 
{"attention": {"k_proj": {"bias": 7.152557373046875e-05, "kernel": 0.1318359375}, "out_proj": {"bias": 0.0152587890625, "kernel": 0.349609375}, "q_proj": {"bias": 0.00823974609375, "kernel": 0.119140625}, "v_proj": {"bias": 0.02685546875, "kernel": 0.31640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.022216796875, "kernel": 0.341796875}, "output_dense": {"bias": 0.014404296875, "kernel": 0.271484375}}, "final_layer_norm": {"bias": 0.0380859375, "scale": 0.0322265625}, "layer_norm": {"bias": 0.049560546875, "scale": 0.058349609375}}, "7": {"attention": {"k_proj": {"bias": 7.62939453125e-05, "kernel": 0.1328125}, "out_proj": {"bias": 0.0150146484375, "kernel": 0.349609375}, "q_proj": {"bias": 0.00921630859375, "kernel": 0.126953125}, "v_proj": {"bias": 0.0255126953125, "kernel": 0.30859375}}, "feed_forward": {"intermediate_dense": {"bias": 0.021484375, "kernel": 0.33984375}, "output_dense": {"bias": 0.01409912109375, "kernel": 0.26171875}}, "final_layer_norm": {"bias": 0.038818359375, "scale": 0.033935546875}, "layer_norm": {"bias": 0.05126953125, "scale": 0.050048828125}}, "8": {"attention": {"k_proj": {"bias": 7.915496826171875e-05, "kernel": 0.12060546875}, "out_proj": {"bias": 0.01507568359375, "kernel": 0.302734375}, "q_proj": {"bias": 0.007568359375, "kernel": 0.11474609375}, "v_proj": {"bias": 0.0262451171875, "kernel": 0.27734375}}, "feed_forward": {"intermediate_dense": {"bias": 0.023193359375, "kernel": 0.361328125}, "output_dense": {"bias": 0.01409912109375, "kernel": 0.275390625}}, "final_layer_norm": {"bias": 0.044677734375, "scale": 0.033447265625}, "layer_norm": {"bias": 0.0498046875, "scale": 0.06103515625}}, "9": {"attention": {"k_proj": {"bias": 0.00011777877807617188, "kernel": 0.1513671875}, "out_proj": {"bias": 0.0140380859375, "kernel": 0.416015625}, "q_proj": {"bias": 0.00830078125, "kernel": 0.1376953125}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.40234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0213623046875, "kernel": 0.3515625}, "output_dense": {"bias": 0.0135498046875, "kernel": 0.279296875}}, "final_layer_norm": {"bias": 0.037109375, "scale": 0.0322265625}, "layer_norm": {"bias": 0.04443359375, "scale": 0.04833984375}}}, "pos_conv_embed": {"conv": {"bias": 0.034912109375, "weight_g": 0.044189453125, "weight_v": 0.287109375}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "1": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "2": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "3": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "4": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "5": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "6": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}}}, "feature_projection": {"layer_norm": {"bias": 0.1337890625, "scale": 0.1611328125}, "projection": {"bias": 0.0556640625, "kernel": 1.0546875}}, "masked_spec_embed": 0.0}}, "layer_param_norm/": {"lm_head": {"bias": 0.7824921607971191, "kernel": 55.72966766357422}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 59.6768798828125, "scale": 74.17054748535156}, "layers": {"0": {"attention": {"k_proj": {"bias": 0.37033993005752563, "kernel": 27.536663055419922}, "out_proj": {"bias": 1.6469175815582275, "kernel": 26.147050857543945}, "q_proj": {"bias": 1.5330281257629395, 
"kernel": 27.813282012939453}, "v_proj": {"bias": 0.44783300161361694, "kernel": 26.55841064453125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9835489988327026, "kernel": 100.66567993164062}, "output_dense": {"bias": 1.116748571395874, "kernel": 96.76679992675781}}, "final_layer_norm": {"bias": 1.335214376449585, "scale": 19.85782241821289}, "layer_norm": {"bias": 2.923041343688965, "scale": 15.398418426513672}}, "1": {"attention": {"k_proj": {"bias": 0.3767347037792206, "kernel": 41.013240814208984}, "out_proj": {"bias": 1.3653483390808105, "kernel": 43.371070861816406}, "q_proj": {"bias": 3.0925614833831787, "kernel": 41.05661392211914}, "v_proj": {"bias": 0.2924947738647461, "kernel": 41.61189270019531}}, "feed_forward": {"intermediate_dense": {"bias": 1.9882662296295166, "kernel": 98.7442626953125}, "output_dense": {"bias": 0.8527815341949463, "kernel": 87.83541870117188}}, "final_layer_norm": {"bias": 1.3605934381484985, "scale": 19.084806442260742}, "layer_norm": {"bias": 1.9318764209747314, "scale": 17.761367797851562}}, "10": {"attention": {"k_proj": {"bias": 0.4123449921607971, "kernel": 49.44670486450195}, "out_proj": {"bias": 1.3130683898925781, "kernel": 52.27025604248047}, "q_proj": {"bias": 2.48445200920105, "kernel": 49.528873443603516}, "v_proj": {"bias": 0.3416975736618042, "kernel": 52.344085693359375}}, "feed_forward": {"intermediate_dense": {"bias": 1.974550485610962, "kernel": 102.70410919189453}, "output_dense": {"bias": 0.5955485105514526, "kernel": 95.81275939941406}}, "final_layer_norm": {"bias": 2.3762400150299072, "scale": 20.81279754638672}, "layer_norm": {"bias": 1.806241512298584, "scale": 21.429487228393555}}, "11": {"attention": {"k_proj": {"bias": 0.4460787773132324, "kernel": 49.37338638305664}, "out_proj": {"bias": 1.1512463092803955, "kernel": 51.949554443359375}, "q_proj": {"bias": 2.5446064472198486, "kernel": 49.20353698730469}, "v_proj": {"bias": 0.40995872020721436, "kernel": 52.2607536315918}}, "feed_forward": {"intermediate_dense": {"bias": 2.019528388977051, "kernel": 103.56025695800781}, "output_dense": {"bias": 0.5711302757263184, "kernel": 97.53776550292969}}, "final_layer_norm": {"bias": 2.3660812377929688, "scale": 20.919677734375}, "layer_norm": {"bias": 1.7802445888519287, "scale": 22.01519203186035}}, "12": {"attention": {"k_proj": {"bias": 0.4312320351600647, "kernel": 50.07032775878906}, "out_proj": {"bias": 1.126122236251831, "kernel": 51.988826751708984}, "q_proj": {"bias": 2.4090728759765625, "kernel": 49.92729949951172}, "v_proj": {"bias": 0.4024103879928589, "kernel": 52.296756744384766}}, "feed_forward": {"intermediate_dense": {"bias": 2.0548508167266846, "kernel": 104.54823303222656}, "output_dense": {"bias": 0.5548778772354126, "kernel": 99.2693099975586}}, "final_layer_norm": {"bias": 2.2933573722839355, "scale": 20.85626983642578}, "layer_norm": {"bias": 1.8587299585342407, "scale": 22.473487854003906}}, "13": {"attention": {"k_proj": {"bias": 0.4430793821811676, "kernel": 51.76431655883789}, "out_proj": {"bias": 1.1271920204162598, "kernel": 51.86852264404297}, "q_proj": {"bias": 2.359200954437256, "kernel": 51.75225830078125}, "v_proj": {"bias": 0.3906242251396179, "kernel": 51.92781066894531}}, "feed_forward": {"intermediate_dense": {"bias": 2.0941619873046875, "kernel": 105.30806732177734}, "output_dense": {"bias": 0.5719542503356934, "kernel": 99.8712158203125}}, "final_layer_norm": {"bias": 2.2314066886901855, "scale": 21.027400970458984}, "layer_norm": {"bias": 1.9997800588607788, "scale": 22.84510040283203}}, 
"14": {"attention": {"k_proj": {"bias": 0.43604815006256104, "kernel": 51.92181396484375}, "out_proj": {"bias": 1.2681217193603516, "kernel": 49.762760162353516}, "q_proj": {"bias": 2.4942922592163086, "kernel": 52.049808502197266}, "v_proj": {"bias": 0.36820662021636963, "kernel": 49.26283264160156}}, "feed_forward": {"intermediate_dense": {"bias": 2.1335830688476562, "kernel": 105.94457244873047}, "output_dense": {"bias": 0.6063626408576965, "kernel": 101.24859619140625}}, "final_layer_norm": {"bias": 2.2754664421081543, "scale": 21.145992279052734}, "layer_norm": {"bias": 2.1295526027679443, "scale": 22.584672927856445}}, "15": {"attention": {"k_proj": {"bias": 0.45931702852249146, "kernel": 51.93058776855469}, "out_proj": {"bias": 1.3753983974456787, "kernel": 50.94449234008789}, "q_proj": {"bias": 2.5871663093566895, "kernel": 52.099769592285156}, "v_proj": {"bias": 0.459650456905365, "kernel": 50.5831298828125}}, "feed_forward": {"intermediate_dense": {"bias": 2.132938861846924, "kernel": 105.57443237304688}, "output_dense": {"bias": 0.7701732516288757, "kernel": 101.94094848632812}}, "final_layer_norm": {"bias": 2.327320098876953, "scale": 21.192947387695312}, "layer_norm": {"bias": 2.4148712158203125, "scale": 23.526634216308594}}, "16": {"attention": {"k_proj": {"bias": 0.4008745551109314, "kernel": 51.772621154785156}, "out_proj": {"bias": 1.27531099319458, "kernel": 50.134521484375}, "q_proj": {"bias": 2.667466163635254, "kernel": 51.75814437866211}, "v_proj": {"bias": 0.3768249750137329, "kernel": 49.78093719482422}}, "feed_forward": {"intermediate_dense": {"bias": 2.1094985008239746, "kernel": 106.10227966308594}, "output_dense": {"bias": 0.7860437631607056, "kernel": 102.66590881347656}}, "final_layer_norm": {"bias": 2.337951421737671, "scale": 21.583194732666016}, "layer_norm": {"bias": 2.283249855041504, "scale": 22.168060302734375}}, "17": {"attention": {"k_proj": {"bias": 0.3966267704963684, "kernel": 51.728885650634766}, "out_proj": {"bias": 1.2151354551315308, "kernel": 49.4556884765625}, "q_proj": {"bias": 2.714320182800293, "kernel": 51.81880187988281}, "v_proj": {"bias": 0.42661017179489136, "kernel": 49.11927032470703}}, "feed_forward": {"intermediate_dense": {"bias": 2.1101765632629395, "kernel": 107.14872741699219}, "output_dense": {"bias": 0.8218655586242676, "kernel": 103.06423950195312}}, "final_layer_norm": {"bias": 2.383938789367676, "scale": 22.070323944091797}, "layer_norm": {"bias": 2.222898483276367, "scale": 21.219982147216797}}, "18": {"attention": {"k_proj": {"bias": 0.4409676194190979, "kernel": 52.41611099243164}, "out_proj": {"bias": 1.3447906970977783, "kernel": 50.491905212402344}, "q_proj": {"bias": 2.614685535430908, "kernel": 52.796600341796875}, "v_proj": {"bias": 0.4518332779407501, "kernel": 50.001895904541016}}, "feed_forward": {"intermediate_dense": {"bias": 2.144195556640625, "kernel": 107.41338348388672}, "output_dense": {"bias": 0.9481453895568848, "kernel": 104.72514343261719}}, "final_layer_norm": {"bias": 2.5390124320983887, "scale": 22.15178680419922}, "layer_norm": {"bias": 2.424910068511963, "scale": 23.585906982421875}}, "19": {"attention": {"k_proj": {"bias": 0.38193291425704956, "kernel": 51.511146545410156}, "out_proj": {"bias": 1.3303101062774658, "kernel": 50.10035705566406}, "q_proj": {"bias": 2.930327892303467, "kernel": 51.865638732910156}, "v_proj": {"bias": 0.4086824655532837, "kernel": 49.38078308105469}}, "feed_forward": {"intermediate_dense": {"bias": 2.1912901401519775, "kernel": 107.95254516601562}, "output_dense": 
{"bias": 1.0248571634292603, "kernel": 105.65098571777344}}, "final_layer_norm": {"bias": 2.4923481941223145, "scale": 22.505674362182617}, "layer_norm": {"bias": 2.2888314723968506, "scale": 22.31826400756836}}, "2": {"attention": {"k_proj": {"bias": 0.454792320728302, "kernel": 47.77275085449219}, "out_proj": {"bias": 1.256988525390625, "kernel": 45.969764709472656}, "q_proj": {"bias": 3.2510807514190674, "kernel": 47.61664581298828}, "v_proj": {"bias": 0.339598685503006, "kernel": 45.72273254394531}}, "feed_forward": {"intermediate_dense": {"bias": 1.9737317562103271, "kernel": 103.32754516601562}, "output_dense": {"bias": 0.7398276329040527, "kernel": 91.11263275146484}}, "final_layer_norm": {"bias": 1.5421981811523438, "scale": 21.561111450195312}, "layer_norm": {"bias": 1.7081801891326904, "scale": 20.852447509765625}}, "20": {"attention": {"k_proj": {"bias": 0.4067543148994446, "kernel": 51.605438232421875}, "out_proj": {"bias": 1.359946370124817, "kernel": 49.45553207397461}, "q_proj": {"bias": 2.8498687744140625, "kernel": 52.224571228027344}, "v_proj": {"bias": 0.36227869987487793, "kernel": 48.43864822387695}}, "feed_forward": {"intermediate_dense": {"bias": 2.1725549697875977, "kernel": 109.17405700683594}, "output_dense": {"bias": 1.1388803720474243, "kernel": 106.40528106689453}}, "final_layer_norm": {"bias": 2.435314655303955, "scale": 23.4317626953125}, "layer_norm": {"bias": 2.231672525405884, "scale": 22.230525970458984}}, "21": {"attention": {"k_proj": {"bias": 0.4161534905433655, "kernel": 51.942527770996094}, "out_proj": {"bias": 1.403618335723877, "kernel": 49.51059341430664}, "q_proj": {"bias": 2.7690629959106445, "kernel": 52.67078399658203}, "v_proj": {"bias": 0.41060006618499756, "kernel": 48.64883041381836}}, "feed_forward": {"intermediate_dense": {"bias": 2.2174296379089355, "kernel": 109.5155029296875}, "output_dense": {"bias": 1.253208041191101, "kernel": 106.88243865966797}}, "final_layer_norm": {"bias": 2.4632763862609863, "scale": 23.175764083862305}, "layer_norm": {"bias": 2.2785892486572266, "scale": 22.234222412109375}}, "22": {"attention": {"k_proj": {"bias": 0.45357397198677063, "kernel": 52.54576110839844}, "out_proj": {"bias": 1.349219560623169, "kernel": 49.533172607421875}, "q_proj": {"bias": 2.8105549812316895, "kernel": 52.86981201171875}, "v_proj": {"bias": 0.3973655700683594, "kernel": 49.33363342285156}}, "feed_forward": {"intermediate_dense": {"bias": 2.1619315147399902, "kernel": 109.95498657226562}, "output_dense": {"bias": 1.3076066970825195, "kernel": 106.3852310180664}}, "final_layer_norm": {"bias": 2.3642821311950684, "scale": 22.684059143066406}, "layer_norm": {"bias": 2.3316237926483154, "scale": 21.545879364013672}}, "23": {"attention": {"k_proj": {"bias": 0.4928613007068634, "kernel": 53.47669219970703}, "out_proj": {"bias": 1.564335823059082, "kernel": 50.98707580566406}, "q_proj": {"bias": 2.7065773010253906, "kernel": 53.582611083984375}, "v_proj": {"bias": 0.5810648202896118, "kernel": 51.54853820800781}}, "feed_forward": {"intermediate_dense": {"bias": 2.131969690322876, "kernel": 109.86410522460938}, "output_dense": {"bias": 1.2769315242767334, "kernel": 107.37890625}}, "final_layer_norm": {"bias": 2.767916679382324, "scale": 22.887813568115234}, "layer_norm": {"bias": 2.824352264404297, "scale": 23.373172760009766}}, "24": {"attention": {"k_proj": {"bias": 0.46056002378463745, "kernel": 52.424072265625}, "out_proj": {"bias": 1.6070430278778076, "kernel": 52.50334167480469}, "q_proj": {"bias": 2.828113079071045, "kernel": 
52.40515899658203}, "v_proj": {"bias": 0.5424190163612366, "kernel": 52.51116180419922}}, "feed_forward": {"intermediate_dense": {"bias": 2.2367913722991943, "kernel": 109.35035705566406}, "output_dense": {"bias": 1.3016372919082642, "kernel": 110.30095672607422}}, "final_layer_norm": {"bias": 2.83841872215271, "scale": 22.964658737182617}, "layer_norm": {"bias": 2.56215763092041, "scale": 22.983924865722656}}, "25": {"attention": {"k_proj": {"bias": 0.42509031295776367, "kernel": 52.730464935302734}, "out_proj": {"bias": 1.363797664642334, "kernel": 50.5806884765625}, "q_proj": {"bias": 2.9342763423919678, "kernel": 52.548744201660156}, "v_proj": {"bias": 0.6404213309288025, "kernel": 51.0885009765625}}, "feed_forward": {"intermediate_dense": {"bias": 2.1367578506469727, "kernel": 109.70021057128906}, "output_dense": {"bias": 1.1017413139343262, "kernel": 110.27072143554688}}, "final_layer_norm": {"bias": 2.5763301849365234, "scale": 23.494670867919922}, "layer_norm": {"bias": 2.683134078979492, "scale": 21.88357925415039}}, "26": {"attention": {"k_proj": {"bias": 0.4836847186088562, "kernel": 53.01764678955078}, "out_proj": {"bias": 1.2433912754058838, "kernel": 51.37077331542969}, "q_proj": {"bias": 2.943906784057617, "kernel": 52.80891036987305}, "v_proj": {"bias": 0.5064959526062012, "kernel": 52.004638671875}}, "feed_forward": {"intermediate_dense": {"bias": 2.2763516902923584, "kernel": 109.44652557373047}, "output_dense": {"bias": 1.0912110805511475, "kernel": 107.40899658203125}}, "final_layer_norm": {"bias": 2.1937994956970215, "scale": 22.433353424072266}, "layer_norm": {"bias": 2.497119903564453, "scale": 22.19057273864746}}, "27": {"attention": {"k_proj": {"bias": 0.5808594226837158, "kernel": 53.76898956298828}, "out_proj": {"bias": 1.5447406768798828, "kernel": 52.95805358886719}, "q_proj": {"bias": 2.703345775604248, "kernel": 53.69578552246094}, "v_proj": {"bias": 0.6748642325401306, "kernel": 53.388118743896484}}, "feed_forward": {"intermediate_dense": {"bias": 2.404933452606201, "kernel": 107.8713150024414}, "output_dense": {"bias": 0.9485896825790405, "kernel": 107.17198181152344}}, "final_layer_norm": {"bias": 2.5252954959869385, "scale": 21.88959503173828}, "layer_norm": {"bias": 2.6147172451019287, "scale": 23.32440948486328}}, "28": {"attention": {"k_proj": {"bias": 0.5901432037353516, "kernel": 54.482521057128906}, "out_proj": {"bias": 1.5367379188537598, "kernel": 53.31493377685547}, "q_proj": {"bias": 2.9472482204437256, "kernel": 54.1741943359375}, "v_proj": {"bias": 0.5131911039352417, "kernel": 53.759761810302734}}, "feed_forward": {"intermediate_dense": {"bias": 2.3475265502929688, "kernel": 107.87416076660156}, "output_dense": {"bias": 0.8224154710769653, "kernel": 109.1680908203125}}, "final_layer_norm": {"bias": 2.425306797027588, "scale": 22.337677001953125}, "layer_norm": {"bias": 2.0914058685302734, "scale": 23.993711471557617}}, "29": {"attention": {"k_proj": {"bias": 0.46781182289123535, "kernel": 51.12034606933594}, "out_proj": {"bias": 1.5021522045135498, "kernel": 55.685630798339844}, "q_proj": {"bias": 2.809702157974243, "kernel": 51.00274658203125}, "v_proj": {"bias": 0.4760415554046631, "kernel": 55.703304290771484}}, "feed_forward": {"intermediate_dense": {"bias": 2.297222137451172, "kernel": 108.01033020019531}, "output_dense": {"bias": 0.9597339630126953, "kernel": 113.12825012207031}}, "final_layer_norm": {"bias": 2.5980498790740967, "scale": 23.459980010986328}, "layer_norm": {"bias": 2.245180130004883, "scale": 25.39927864074707}}, "3": 
{"attention": {"k_proj": {"bias": 0.45006245374679565, "kernel": 52.03215789794922}, "out_proj": {"bias": 1.4254932403564453, "kernel": 48.60858917236328}, "q_proj": {"bias": 2.8560738563537598, "kernel": 52.312644958496094}, "v_proj": {"bias": 0.3246268630027771, "kernel": 48.768699645996094}}, "feed_forward": {"intermediate_dense": {"bias": 1.9663825035095215, "kernel": 104.83622741699219}, "output_dense": {"bias": 0.6984099745750427, "kernel": 94.07957458496094}}, "final_layer_norm": {"bias": 1.8095453977584839, "scale": 21.664737701416016}, "layer_norm": {"bias": 1.9017157554626465, "scale": 22.739452362060547}}, "30": {"attention": {"k_proj": {"bias": 0.5024805665016174, "kernel": 52.825706481933594}, "out_proj": {"bias": 1.3023658990859985, "kernel": 52.053871154785156}, "q_proj": {"bias": 2.907101631164551, "kernel": 52.91836166381836}, "v_proj": {"bias": 0.49308842420578003, "kernel": 52.49382019042969}}, "feed_forward": {"intermediate_dense": {"bias": 2.2399911880493164, "kernel": 108.17861938476562}, "output_dense": {"bias": 0.9140658378601074, "kernel": 112.09104919433594}}, "final_layer_norm": {"bias": 2.4926414489746094, "scale": 24.492368698120117}, "layer_norm": {"bias": 2.316732168197632, "scale": 24.931156158447266}}, "31": {"attention": {"k_proj": {"bias": 0.5412741899490356, "kernel": 51.240806579589844}, "out_proj": {"bias": 1.2333163022994995, "kernel": 52.19988250732422}, "q_proj": {"bias": 2.6581294536590576, "kernel": 51.346168518066406}, "v_proj": {"bias": 0.5469827651977539, "kernel": 52.432586669921875}}, "feed_forward": {"intermediate_dense": {"bias": 2.3097352981567383, "kernel": 106.72758483886719}, "output_dense": {"bias": 1.0891624689102173, "kernel": 109.24717712402344}}, "final_layer_norm": {"bias": 2.2962756156921387, "scale": 24.31252670288086}, "layer_norm": {"bias": 2.3430848121643066, "scale": 24.590187072753906}}, "32": {"attention": {"k_proj": {"bias": 0.4704548716545105, "kernel": 50.39933776855469}, "out_proj": {"bias": 1.2453913688659668, "kernel": 51.57465362548828}, "q_proj": {"bias": 2.8450098037719727, "kernel": 50.34847640991211}, "v_proj": {"bias": 0.419519305229187, "kernel": 51.972320556640625}}, "feed_forward": {"intermediate_dense": {"bias": 2.2569355964660645, "kernel": 105.33018493652344}, "output_dense": {"bias": 1.146787166595459, "kernel": 108.35574340820312}}, "final_layer_norm": {"bias": 2.31538724899292, "scale": 24.518985748291016}, "layer_norm": {"bias": 2.417579174041748, "scale": 24.991830825805664}}, "33": {"attention": {"k_proj": {"bias": 0.48390907049179077, "kernel": 50.28266143798828}, "out_proj": {"bias": 1.280959129333496, "kernel": 51.30557632446289}, "q_proj": {"bias": 2.998173713684082, "kernel": 50.253868103027344}, "v_proj": {"bias": 0.4416005611419678, "kernel": 51.71035385131836}}, "feed_forward": {"intermediate_dense": {"bias": 2.279946804046631, "kernel": 103.67684173583984}, "output_dense": {"bias": 1.1754591464996338, "kernel": 106.81501007080078}}, "final_layer_norm": {"bias": 2.257889747619629, "scale": 24.207740783691406}, "layer_norm": {"bias": 2.5865607261657715, "scale": 25.067882537841797}}, "34": {"attention": {"k_proj": {"bias": 0.45390456914901733, "kernel": 49.26523208618164}, "out_proj": {"bias": 1.5265402793884277, "kernel": 52.46200942993164}, "q_proj": {"bias": 2.913527011871338, "kernel": 49.268951416015625}, "v_proj": {"bias": 0.4023621678352356, "kernel": 52.534793853759766}}, "feed_forward": {"intermediate_dense": {"bias": 2.3745017051696777, "kernel": 102.21138000488281}, "output_dense": 
{"bias": 1.1227428913116455, "kernel": 105.72161102294922}}, "final_layer_norm": {"bias": 2.20273494720459, "scale": 23.640857696533203}, "layer_norm": {"bias": 2.615731716156006, "scale": 25.498104095458984}}, "35": {"attention": {"k_proj": {"bias": 0.5336894989013672, "kernel": 51.04902648925781}, "out_proj": {"bias": 1.4900906085968018, "kernel": 51.15006637573242}, "q_proj": {"bias": 2.573650360107422, "kernel": 51.323951721191406}, "v_proj": {"bias": 0.4889468550682068, "kernel": 51.176795959472656}}, "feed_forward": {"intermediate_dense": {"bias": 2.502547264099121, "kernel": 100.7535400390625}, "output_dense": {"bias": 1.0254027843475342, "kernel": 104.25007629394531}}, "final_layer_norm": {"bias": 2.295243263244629, "scale": 23.61981964111328}, "layer_norm": {"bias": 2.499863862991333, "scale": 26.093618392944336}}, "36": {"attention": {"k_proj": {"bias": 0.4473738670349121, "kernel": 48.319740295410156}, "out_proj": {"bias": 1.5177178382873535, "kernel": 52.26478958129883}, "q_proj": {"bias": 2.6124308109283447, "kernel": 48.236427307128906}, "v_proj": {"bias": 0.39316776394844055, "kernel": 52.66499328613281}}, "feed_forward": {"intermediate_dense": {"bias": 2.3669564723968506, "kernel": 99.60517120361328}, "output_dense": {"bias": 1.0260361433029175, "kernel": 103.68067932128906}}, "final_layer_norm": {"bias": 2.044203519821167, "scale": 24.13540267944336}, "layer_norm": {"bias": 2.2894434928894043, "scale": 25.63806915283203}}, "37": {"attention": {"k_proj": {"bias": 0.6240901350975037, "kernel": 47.300960540771484}, "out_proj": {"bias": 1.7604304552078247, "kernel": 52.189971923828125}, "q_proj": {"bias": 2.3819239139556885, "kernel": 47.316253662109375}, "v_proj": {"bias": 0.38518577814102173, "kernel": 52.32656478881836}}, "feed_forward": {"intermediate_dense": {"bias": 2.2694783210754395, "kernel": 98.56558227539062}, "output_dense": {"bias": 1.0087021589279175, "kernel": 103.13001251220703}}, "final_layer_norm": {"bias": 1.7903382778167725, "scale": 24.500085830688477}, "layer_norm": {"bias": 2.238807439804077, "scale": 25.563133239746094}}, "38": {"attention": {"k_proj": {"bias": 0.722891092300415, "kernel": 45.45281219482422}, "out_proj": {"bias": 1.4463956356048584, "kernel": 51.495086669921875}, "q_proj": {"bias": 2.2622861862182617, "kernel": 45.454437255859375}, "v_proj": {"bias": 0.42936205863952637, "kernel": 51.565425872802734}}, "feed_forward": {"intermediate_dense": {"bias": 2.2012171745300293, "kernel": 96.4342041015625}, "output_dense": {"bias": 0.9817801713943481, "kernel": 101.30387115478516}}, "final_layer_norm": {"bias": 1.7825044393539429, "scale": 25.214679718017578}, "layer_norm": {"bias": 2.4084482192993164, "scale": 26.425636291503906}}, "39": {"attention": {"k_proj": {"bias": 0.7209377884864807, "kernel": 45.24742889404297}, "out_proj": {"bias": 1.7097458839416504, "kernel": 51.329002380371094}, "q_proj": {"bias": 2.1152358055114746, "kernel": 45.53700637817383}, "v_proj": {"bias": 0.4246286153793335, "kernel": 51.27728271484375}}, "feed_forward": {"intermediate_dense": {"bias": 2.177624225616455, "kernel": 94.23902893066406}, "output_dense": {"bias": 1.0458879470825195, "kernel": 101.17718505859375}}, "final_layer_norm": {"bias": 1.753927230834961, "scale": 25.785865783691406}, "layer_norm": {"bias": 2.33198881149292, "scale": 26.942312240600586}}, "4": {"attention": {"k_proj": {"bias": 0.44548165798187256, "kernel": 54.6234130859375}, "out_proj": {"bias": 1.652343988418579, "kernel": 50.16497039794922}, "q_proj": {"bias": 2.615248918533325, 
"kernel": 54.93098068237305}, "v_proj": {"bias": 0.34892427921295166, "kernel": 50.320098876953125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9531786441802979, "kernel": 104.50968933105469}, "output_dense": {"bias": 0.8575068712234497, "kernel": 95.54541015625}}, "final_layer_norm": {"bias": 1.9920002222061157, "scale": 21.200180053710938}, "layer_norm": {"bias": 2.054612159729004, "scale": 23.620338439941406}}, "40": {"attention": {"k_proj": {"bias": 0.6590453386306763, "kernel": 44.198089599609375}, "out_proj": {"bias": 1.6252505779266357, "kernel": 49.55699920654297}, "q_proj": {"bias": 1.9674756526947021, "kernel": 44.89208221435547}, "v_proj": {"bias": 0.4587768614292145, "kernel": 49.23614501953125}}, "feed_forward": {"intermediate_dense": {"bias": 2.0333969593048096, "kernel": 92.16896057128906}, "output_dense": {"bias": 1.087776780128479, "kernel": 98.40738677978516}}, "final_layer_norm": {"bias": 1.7852704524993896, "scale": 25.04292106628418}, "layer_norm": {"bias": 2.2756104469299316, "scale": 26.40799903869629}}, "41": {"attention": {"k_proj": {"bias": 1.7133712768554688, "kernel": 41.96858596801758}, "out_proj": {"bias": 1.3790823221206665, "kernel": 51.25593566894531}, "q_proj": {"bias": 1.71382737159729, "kernel": 42.56317901611328}, "v_proj": {"bias": 0.4695759415626526, "kernel": 50.369529724121094}}, "feed_forward": {"intermediate_dense": {"bias": 2.110393524169922, "kernel": 88.92567443847656}, "output_dense": {"bias": 1.1446669101715088, "kernel": 97.37409973144531}}, "final_layer_norm": {"bias": 2.23917293548584, "scale": 28.507984161376953}, "layer_norm": {"bias": 2.22525691986084, "scale": 28.246891021728516}}, "42": {"attention": {"k_proj": {"bias": 0.8601109981536865, "kernel": 38.31235885620117}, "out_proj": {"bias": 1.4427157640457153, "kernel": 45.07648849487305}, "q_proj": {"bias": 1.549715280532837, "kernel": 39.524009704589844}, "v_proj": {"bias": 0.6933339834213257, "kernel": 43.49076461791992}}, "feed_forward": {"intermediate_dense": {"bias": 1.9107489585876465, "kernel": 88.009765625}, "output_dense": {"bias": 1.1978566646575928, "kernel": 95.7593994140625}}, "final_layer_norm": {"bias": 1.9227323532104492, "scale": 29.817535400390625}, "layer_norm": {"bias": 1.6761282682418823, "scale": 26.810440063476562}}, "43": {"attention": {"k_proj": {"bias": 1.247081995010376, "kernel": 34.694725036621094}, "out_proj": {"bias": 1.4174811840057373, "kernel": 41.36320495605469}, "q_proj": {"bias": 1.3773530721664429, "kernel": 35.38981628417969}, "v_proj": {"bias": 0.5787136554718018, "kernel": 39.29212951660156}}, "feed_forward": {"intermediate_dense": {"bias": 1.8936835527420044, "kernel": 87.073974609375}, "output_dense": {"bias": 0.9419379234313965, "kernel": 93.74283599853516}}, "final_layer_norm": {"bias": 1.99924898147583, "scale": 32.0491943359375}, "layer_norm": {"bias": 1.7940990924835205, "scale": 25.131242752075195}}, "44": {"attention": {"k_proj": {"bias": 2.5188145637512207, "kernel": 35.16314697265625}, "out_proj": {"bias": 1.1667875051498413, "kernel": 45.019126892089844}, "q_proj": {"bias": 1.317202091217041, "kernel": 35.58300018310547}, "v_proj": {"bias": 0.38874924182891846, "kernel": 44.14718246459961}}, "feed_forward": {"intermediate_dense": {"bias": 1.9462898969650269, "kernel": 86.07953643798828}, "output_dense": {"bias": 0.859266996383667, "kernel": 91.58765411376953}}, "final_layer_norm": {"bias": 2.0454273223876953, "scale": 34.2881965637207}, "layer_norm": {"bias": 1.6815991401672363, "scale": 25.14142608642578}}, "45": 
{"attention": {"k_proj": {"bias": 2.081407308578491, "kernel": 34.86139678955078}, "out_proj": {"bias": 1.0356104373931885, "kernel": 48.59937286376953}, "q_proj": {"bias": 1.402512788772583, "kernel": 35.03264617919922}, "v_proj": {"bias": 0.4231463074684143, "kernel": 48.76853942871094}}, "feed_forward": {"intermediate_dense": {"bias": 2.016927719116211, "kernel": 82.93773651123047}, "output_dense": {"bias": 0.9764893054962158, "kernel": 87.24796295166016}}, "final_layer_norm": {"bias": 1.9180456399917603, "scale": 33.143672943115234}, "layer_norm": {"bias": 1.5726068019866943, "scale": 23.782546997070312}}, "46": {"attention": {"k_proj": {"bias": 1.5659263134002686, "kernel": 35.878021240234375}, "out_proj": {"bias": 0.8182340264320374, "kernel": 51.16078186035156}, "q_proj": {"bias": 1.5642974376678467, "kernel": 36.18907165527344}, "v_proj": {"bias": 0.4092414081096649, "kernel": 51.89159393310547}}, "feed_forward": {"intermediate_dense": {"bias": 2.0093321800231934, "kernel": 77.47581481933594}, "output_dense": {"bias": 1.1406863927841187, "kernel": 77.73695373535156}}, "final_layer_norm": {"bias": 1.8108854293823242, "scale": 28.70657730102539}, "layer_norm": {"bias": 1.3991491794586182, "scale": 22.808137893676758}}, "47": {"attention": {"k_proj": {"bias": 0.6173280477523804, "kernel": 38.678985595703125}, "out_proj": {"bias": 0.6758822202682495, "kernel": 46.45281219482422}, "q_proj": {"bias": 1.7084776163101196, "kernel": 39.426841735839844}, "v_proj": {"bias": 0.4932914674282074, "kernel": 47.617279052734375}}, "feed_forward": {"intermediate_dense": {"bias": 1.986911654472351, "kernel": 75.47482299804688}, "output_dense": {"bias": 0.6346586346626282, "kernel": 72.82707214355469}}, "final_layer_norm": {"bias": 1.1888140439987183, "scale": 23.650447845458984}, "layer_norm": {"bias": 1.2521969079971313, "scale": 20.66573715209961}}, "5": {"attention": {"k_proj": {"bias": 0.42588678002357483, "kernel": 50.1945686340332}, "out_proj": {"bias": 1.6038882732391357, "kernel": 51.2144889831543}, "q_proj": {"bias": 2.7522244453430176, "kernel": 50.37500762939453}, "v_proj": {"bias": 0.3343381881713867, "kernel": 51.71652603149414}}, "feed_forward": {"intermediate_dense": {"bias": 1.8887722492218018, "kernel": 104.60663604736328}, "output_dense": {"bias": 0.8976269960403442, "kernel": 94.77360534667969}}, "final_layer_norm": {"bias": 2.1965675354003906, "scale": 21.37998390197754}, "layer_norm": {"bias": 2.0435237884521484, "scale": 22.437192916870117}}, "6": {"attention": {"k_proj": {"bias": 0.4843112528324127, "kernel": 51.87700653076172}, "out_proj": {"bias": 1.5925445556640625, "kernel": 50.83113479614258}, "q_proj": {"bias": 2.7889723777770996, "kernel": 52.3514404296875}, "v_proj": {"bias": 0.3247200846672058, "kernel": 51.107398986816406}}, "feed_forward": {"intermediate_dense": {"bias": 1.8638136386871338, "kernel": 103.7142333984375}, "output_dense": {"bias": 0.752193808555603, "kernel": 94.57742309570312}}, "final_layer_norm": {"bias": 2.5145251750946045, "scale": 20.836563110351562}, "layer_norm": {"bias": 2.0285890102386475, "scale": 23.156789779663086}}, "7": {"attention": {"k_proj": {"bias": 0.5048109889030457, "kernel": 51.46453094482422}, "out_proj": {"bias": 1.4398455619812012, "kernel": 51.139068603515625}, "q_proj": {"bias": 2.550907611846924, "kernel": 51.92047119140625}, "v_proj": {"bias": 0.42719271779060364, "kernel": 50.953285217285156}}, "feed_forward": {"intermediate_dense": {"bias": 1.8739449977874756, "kernel": 103.49991607666016}, "output_dense": {"bias": 
0.5876641273498535, "kernel": 94.39387512207031}}, "final_layer_norm": {"bias": 2.416801929473877, "scale": 21.010677337646484}, "layer_norm": {"bias": 1.9788501262664795, "scale": 22.19708824157715}}, "8": {"attention": {"k_proj": {"bias": 0.49711495637893677, "kernel": 51.12122344970703}, "out_proj": {"bias": 1.2548246383666992, "kernel": 51.65118408203125}, "q_proj": {"bias": 2.541980504989624, "kernel": 51.03407287597656}, "v_proj": {"bias": 0.35420340299606323, "kernel": 51.662872314453125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9298467636108398, "kernel": 103.21916198730469}, "output_dense": {"bias": 0.5482766628265381, "kernel": 93.97574615478516}}, "final_layer_norm": {"bias": 2.3526053428649902, "scale": 20.7393856048584}, "layer_norm": {"bias": 1.9221248626708984, "scale": 22.40435028076172}}, "9": {"attention": {"k_proj": {"bias": 0.5231171250343323, "kernel": 52.01068878173828}, "out_proj": {"bias": 1.4968843460083008, "kernel": 52.671897888183594}, "q_proj": {"bias": 2.4629459381103516, "kernel": 52.26807403564453}, "v_proj": {"bias": 0.38445231318473816, "kernel": 52.86597442626953}}, "feed_forward": {"intermediate_dense": {"bias": 2.026733875274658, "kernel": 101.98421478271484}, "output_dense": {"bias": 0.6828575134277344, "kernel": 94.36962890625}}, "final_layer_norm": {"bias": 2.325080156326294, "scale": 20.160720825195312}, "layer_norm": {"bias": 2.0236480236053467, "scale": 24.083864212036133}}}, "pos_conv_embed": {"conv": {"bias": 5.847014427185059, "weight_g": 9.12463665008545, "weight_v": 93.52015686035156}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 2.0290679931640625, "kernel": 20.55536460876465}, "layer_norm": {"bias": 4.550922393798828, "scale": 16.167570114135742}}, "1": {"conv": {"bias": 1.7790228128433228, "kernel": 51.24136734008789}, "layer_norm": {"bias": 5.962646961212158, "scale": 23.268157958984375}}, "2": {"conv": {"bias": 1.140576720237732, "kernel": 46.50312042236328}, "layer_norm": {"bias": 4.176670551300049, "scale": 20.370853424072266}}, "3": {"conv": {"bias": 0.6725863218307495, "kernel": 44.397525787353516}, "layer_norm": {"bias": 3.888174533843994, "scale": 17.53795051574707}}, "4": {"conv": {"bias": 0.6373162269592285, "kernel": 41.314056396484375}, "layer_norm": {"bias": 2.385471820831299, "scale": 16.34571647644043}}, "5": {"conv": {"bias": 0.5147221684455872, "kernel": 37.479759216308594}, "layer_norm": {"bias": 2.020900011062622, "scale": 17.064470291137695}}, "6": {"conv": {"bias": 0.4947893023490906, "kernel": 40.64780044555664}, "layer_norm": {"bias": 0.5876954793930054, "scale": 19.058603286743164}}}}, "feature_projection": {"layer_norm": {"bias": 6.376383304595947, "scale": 16.443069458007812}, "projection": {"bias": 1.8670344352722168, "kernel": 37.218414306640625}}, "masked_spec_embed": 11.914372444152832}}, "train/learning_rate": 1.9151924789184704e-05, "train/loss": 0.204779714345932, "train/param_norm": 1241.662353515625, "_runtime": 4765, "_timestamp": 1660121820, "_step": 275600, "_wandb": {"runtime": 4766}}
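The wandb-summary.json content above ends with nested "layer_grad_norm/" and "layer_param_norm/" entries, one L2 norm per kernel, bias and scale in the model. A minimal sketch of how such per-layer norms can be computed from a Flax gradient or parameter pytree before logging (the helper name and logging call are illustrative, not necessarily the repository's exact code):

import jax
import jax.numpy as jnp

def layer_l2_norms(tree):
    # Take the L2 norm of every leaf (kernel, bias, scale) while keeping the
    # nested module structure, matching the shape of the layer_grad_norm/ and
    # layer_param_norm/ entries recorded in wandb-summary.json.
    return jax.tree_util.tree_map(lambda x: jnp.linalg.norm(x).item(), tree)

# Illustrative usage after an optimizer step:
#   wandb.log({"layer_grad_norm/": layer_l2_norms(grads)})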
wandb/run-20220810_073735-23avj35z/logs/debug-internal.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:321db68741142aebf6da6dfd07396d57f1844a38e8782fb191cb8b2f9d6ad8f3
3
+ size 182857
wandb/run-20220810_073735-23avj35z/logs/debug.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ac7e803317e52439f444d7689e70325ecbb39546789a4dccfc840ec06a3de97
3
+ size 6204
wandb/run-20220810_073735-23avj35z/run-23avj35z.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c7bcca71ea7ef338baa85a6158b1edd93877dbbb7c2d16bb122668a85532e88
3
+ size 772668
wandb/run-20220810_111559-290849gb/files/code/run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1631 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
18
+ """
19
+ # You can also adapt this script for your own CTC speech recognition task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import unreplicate, pad_shard_unpad
44
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.utils import check_min_version
59
+ from transformers.utils.versions import require_version
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ check_min_version("4.17.0.dev0")
64
+
65
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ @flax.struct.dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ feature_extractor_name: Optional[str] = field(
86
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
91
+ )
92
+ use_fast_tokenizer: bool = field(
93
+ default=True,
94
+ metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
95
+ )
96
+ model_revision: str = field(
97
+ default="main",
98
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
99
+ )
100
+ use_auth_token: bool = field(
101
+ default=False,
102
+ metadata={
103
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
104
+ "with private models)."
105
+ },
106
+ )
107
+ freeze_feature_encoder: bool = field(
108
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
109
+ )
110
+ attention_dropout: float = field(
111
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
112
+ )
113
+ activation_dropout: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
117
+ },
118
+ )
119
+ hidden_dropout: float = field(
120
+ default=0.1,
121
+ metadata={
122
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
123
+ },
124
+ )
125
+ feat_proj_dropout: float = field(
126
+ default=0.0,
127
+ metadata={
128
+ "help": "The feat proj dropout probability for feature encoder representations."
129
+ },
130
+ )
131
+ final_dropout: float = field(
132
+ default=0.0,
133
+ metadata={"help": "The dropout probability for the final projection layer."},
134
+ )
135
+ mask_time_prob: float = field(
136
+ default=0.1,
137
+ metadata={
138
+ "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector span to be masked (SpecAugment)."
139
+ },
140
+ )
141
+ mask_time_length: int = field(
142
+ default=10,
143
+ metadata={"help": "Length of vector span to mask along the time axis."},
144
+ )
145
+ mask_feature_prob: float = field(
146
+ default=0.0,
147
+ metadata={
148
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
149
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the feature axis."
150
+ },
151
+ )
152
+ mask_feature_length: int = field(
153
+ default=10,
154
+ metadata={"help": "Length of vector span to mask along the feature axis."},
155
+ )
156
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
157
+ ctc_loss_reduction: Optional[str] = field(
158
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
159
+ )
160
+ ctc_zero_infinity: Optional[bool] = field(
161
+ default=False, metadata={"help": "If True, will try to avoid the CTC loss going to infinity."}
162
+ )
163
+
164
+
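# Editor's sketch (not part of the original script): the dropout, masking and
# CTC fields above are typically copied onto the Wav2Vec2 config before the
# model is instantiated, along the lines of:
#
#     config.update(
#         {
#             "attention_dropout": model_args.attention_dropout,
#             "hidden_dropout": model_args.hidden_dropout,
#             "feat_proj_dropout": model_args.feat_proj_dropout,
#             "final_dropout": model_args.final_dropout,
#             "mask_time_prob": model_args.mask_time_prob,
#             "mask_time_length": model_args.mask_time_length,
#             "mask_feature_prob": model_args.mask_feature_prob,
#             "mask_feature_length": model_args.mask_feature_length,
#             "layerdrop": model_args.layerdrop,
#             "ctc_loss_reduction": model_args.ctc_loss_reduction,
#             "ctc_zero_infinity": model_args.ctc_zero_infinity,
#         }
#     )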
165
+ @flax.struct.dataclass
166
+ class DataTrainingArguments:
167
+ """
168
+ Arguments pertaining to what data we are going to input our model for training and eval.
169
+ """
170
+
171
+ dataset_name: str = field(
172
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
173
+ )
174
+ dataset_config_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
176
+ )
177
+ text_column: Optional[str] = field(
178
+ default=None,
179
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
180
+ )
181
+ dataset_cache_dir: Optional[str] = field(
182
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
183
+ )
184
+ overwrite_cache: bool = field(
185
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
186
+ )
187
+ preprocessing_num_workers: Optional[int] = field(
188
+ default=None,
189
+ metadata={"help": "The number of processes to use for the preprocessing."},
190
+ )
191
+ max_train_samples: Optional[int] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ },
197
+ )
198
+ max_eval_samples: Optional[int] = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
202
+ "value if set."
203
+ },
204
+ )
205
+ max_test_samples: Optional[int] = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
209
+ "value if set."
210
+ },
211
+ )
212
+ audio_column_name: str = field(
213
+ default="audio",
214
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
215
+ )
216
+ text_column_name: str = field(
217
+ default="text",
218
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
219
+ )
220
+ max_duration_in_seconds: float = field(
221
+ default=20.0,
222
+ metadata={
223
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to `max_duration_in_seconds`"
224
+ },
225
+ )
226
+ min_duration_in_seconds: float = field(
227
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
228
+ )
229
+ max_label_length: Optional[int] = field(
230
+ default=512,
231
+ metadata={
232
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
233
+ "than this will be filtered."
234
+ },
235
+ )
236
+ min_label_length: Optional[int] = field(
237
+ default=2,
238
+ metadata={
239
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
240
+ "than this will be filtered."
241
+ },
242
+ )
243
+ pad_input_to_multiple_of: Optional[int] = field(
244
+ default=32000,
245
+ metadata={
246
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
247
+ "This is important to avoid triggering recompilations on TPU."
248
+ },
249
+ )
250
+ pad_target_to_multiple_of: Optional[int] = field(
251
+ default=None,
252
+ metadata={
253
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
254
+ "This is important to avoid triggering recompilations on TPU."
255
+ },
256
+ )
257
+ preprocessing_only: bool = field(
258
+ default=False,
259
+ metadata={
260
+ "help": "Whether to only do data preprocessing and skip training. "
261
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
262
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
263
+ "so that the cached datasets can consequently be loaded in distributed training"
264
+ },
265
+ )
266
+ train_split_name: str = field(
267
+ default="train",
268
+ metadata={
269
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
270
+ },
271
+ )
272
+ eval_split_name: str = field(
273
+ default="validation",
274
+ metadata={
275
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
276
+ },
277
+ )
278
+ do_lower_case: bool = field(
279
+ default=True,
280
+ metadata={"help": "Whether the target text should be lower cased."},
281
+ )
282
+ wandb_project: str = field(
283
+ default="flax-speech-recognition-ctc",
284
+ metadata={"help": "The name of the wandb project."},
285
+ )
286
+ wandb_name: str = field(
287
+ default=None,
288
+ metadata={"help": "The name of the wandb run."},
289
+ )
290
+ wandb_job_type: str = field(
291
+ default="CTC",
292
+ metadata={"help": "The name of the wandb job type."},
293
+ )
294
+ test_split_name: str = field(
295
+ default="test",
296
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
297
+ )
298
+ remove_punctuation: bool = field(
299
+ default=False, metadata={"help": "Whether or not to remove punctuation during training."}
300
+ )
301
+ skip_steps: Optional[int] = field(
302
+ default=0,
303
+ metadata={
304
+ "help": "Skip this number of steps. Useful to continue training"
305
+ },
306
+ )
307
+
308
+
309
+ # @flax.struct.dataclass
310
+ @dataclass
311
+ class FlaxTrainingArguments(TrainingArguments):
312
+ precision: str = field(
313
+ default="full",
314
+ metadata={
315
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
316
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
317
+ },
318
+ )
319
+ matmul_precision: str = field(
320
+ default="default",
321
+ metadata={
322
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
323
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
324
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
325
+ "it only changes the behaviors of calls with no such argument provided. "
326
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
327
+ },
328
+ )
329
+ multisteps: bool = field(
330
+ default=False,
331
+ metadata={
332
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
333
+ "a custom gradient accumulation implementation will be employed."
334
+ },
335
+ )
336
+
337
+
338
+ def to_fp32(t):
339
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
340
+
341
+
342
+ def to_bf16(t):
343
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
344
+
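As a rough illustration of what these two helpers do (not part of the committed script; the toy pytree below is made up), they walk an arbitrary parameter tree and cast only the floating-point leaves:

import jax.numpy as jnp

params = {"dense": {"kernel": jnp.ones((2, 2), dtype=jnp.float32), "step": jnp.array(0)}}
half = to_bf16(params)       # float32 leaves become bfloat16, the int leaf is untouched
restored = to_fp32(half)     # bfloat16 leaves are cast back to float32
assert restored["dense"]["kernel"].dtype == jnp.float32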
345
+
346
+ class MixedPrecisionTrainState(struct.PyTreeNode):
347
+ """Train state for use with a single Optax optimizer.
348
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
349
+
350
+ Synopsis::
351
+
352
+ state = TrainState.create(
353
+ apply_fn=model.apply,
354
+ params=variables['params'],
355
+ tx=tx)
356
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
357
+ for batch in data:
358
+ grads = grad_fn(state.params, batch)
359
+ state = state.apply_gradients(grads=grads)
360
+
361
+ Args:
362
+ step: Counter starts at 0 and is incremented by every call to
363
+ `.apply_gradients()`.
364
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
365
+ convenience to have a shorter params list for the `train_step()` function
366
+ in your training loop.
367
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
368
+ tx: An Optax gradient transformation.
369
+ opt_state: The state for `tx`.
370
+ dropout_rng: PRNG key for stochastic operations.
371
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
372
+ """
373
+
374
+ step: int
375
+ apply_fn: Callable = struct.field(pytree_node=False)
376
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
377
+ params: core.FrozenDict[str, Any]
378
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
379
+ opt_state: optax.OptState
380
+ dropout_rng: jnp.ndarray
381
+ max_grad_norm: Optional[float] = 1.0
382
+
383
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
384
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
385
+
386
+ Note that internally this function calls `.tx.update()` followed by a call
387
+ to `optax.apply_updates()` to update `params` and `opt_state`.
388
+
389
+ Args:
390
+ grads: Gradients that have the same pytree structure as `.params`.
391
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
392
+
393
+ Returns:
394
+ An updated instance of `self` with `step` incremented by one, `params`
395
+ and `opt_state` updated by applying `grads`, and additional attributes
396
+ replaced as specified by `kwargs`.
397
+ """
398
+
399
+ # clip gradients by global l2 norm
400
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
401
+ g_norm = linear_algebra.global_norm(grads)
402
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
403
+ grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
404
+
405
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
406
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
407
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
408
+
409
+ new_params = optax.apply_updates(self.params, updates)
410
+ return self.replace(
411
+ step=self.step + 1,
412
+ params=new_params,
413
+ opt_state=to_dtype(new_opt_state),
414
+ **kwargs,
415
+ )
416
+
417
+ @classmethod
418
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
419
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
420
+ # downcast optimizer state to bf16 if mixed-precision training
421
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
422
+ return cls(
423
+ step=0,
424
+ apply_fn=apply_fn,
425
+ params=params,
426
+ tx=tx,
427
+ opt_state=opt_state,
428
+ **kwargs,
429
+ )
430
+
431
+ def replicate(self):
432
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
433
+
434
+
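A minimal sketch (illustrative only, with placeholder callables and a dummy one-parameter "model") of how this train state is meant to be driven: the optimizer state is created and stored via `to_dtype`, while `apply_gradients` clips, upcasts to fp32 for the update, and downcasts the new optimizer state again:

import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((4,), dtype=jnp.float32)}
state = MixedPrecisionTrainState.create(
    apply_fn=lambda **kwargs: None,               # placeholder for model.__call__
    get_attention_mask_fn=lambda *args: None,     # placeholder
    params=params,
    tx=optax.adamw(learning_rate=1e-4),
    to_dtype=to_bf16,                             # keep the optimizer state in bfloat16
    dropout_rng=jax.random.PRNGKey(0),
    max_grad_norm=jnp.asarray(1.0),
)
grads = {"w": jnp.full((4,), 0.1, dtype=jnp.bfloat16)}
state = state.apply_gradients(grads=grads, to_dtype=to_bf16)  # clip, update in fp32, downcast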
435
+ @flax.struct.dataclass
436
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
437
+ """
438
+ Data collator that will dynamically pad the inputs received.
439
+ Args:
440
+ processor ([`Wav2Vec2Processor`])
441
+ The processor used for processing the data.
442
+ decoder_start_token_id (:obj: `int`)
443
+ The begin-of-sentence of the decoder.
444
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
445
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
446
+ among:
447
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
448
+ sequence is provided).
449
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
450
+ maximum acceptable input length for the model if that argument is not provided.
451
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
452
+ different lengths).
453
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
454
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
455
+ See above for details.
456
+ max_input_length (:obj:`float`, `optional`):
457
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
458
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
459
+ If set will pad the input sequence to a multiple of the provided value.
460
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
461
+ 7.5 (Volta).
462
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
463
+ If set will pad the target sequence to a multiple of the provided value.
464
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
465
+ 7.5 (Volta).
466
+ """
467
+
468
+ processor: Any
469
+ input_padding: Union[bool, str] = "longest"
470
+ label_padding: Union[bool, str] = "max_length"
471
+ pad_input_to_multiple_of: Optional[int] = None
472
+ pad_to_multiple_of_label: Optional[int] = None
473
+ max_input_length: Optional[float] = None
474
+ max_label_length: Optional[float] = None
475
+
476
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
477
+ # split inputs and labels since they have to be of different lengths and need
478
+ # different padding methods
479
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
480
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
481
+
482
+ # reformat list to dict and set to pytorch format
483
+ batch = self.processor.feature_extractor.pad(
484
+ input_features,
485
+ max_length=self.max_input_length,
486
+ padding=self.input_padding,
487
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
488
+ return_tensors="np",
489
+ )
490
+
491
+ labels_batch = self.processor.tokenizer.pad(
492
+ label_features,
493
+ max_length=self.max_label_length,
494
+ padding=self.label_padding,
495
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
496
+ return_tensors="np",
497
+ )
498
+
499
+ labels = labels_batch["input_ids"]
500
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
501
+ labels = labels.filled(fill_value=-100)
502
+
503
+ batch["labels"] = labels
504
+
505
+ return batch
506
+
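An illustrative sketch of how the collator pads a toy batch; the checkpoint name is only a placeholder for whichever Wav2Vec2 processor the run actually uses:

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")  # placeholder checkpoint
collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
    processor=processor, input_padding="longest", max_label_length=512
)
features = [
    {"input_values": [0.1] * 16000, "labels": processor.tokenizer("HELLO").input_ids},
    {"input_values": [0.2] * 8000, "labels": processor.tokenizer("HI").input_ids},
]
batch = collator(features)  # inputs padded to the longest example, label padding replaced by -100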
507
+
508
+ def get_grouped_indices(
509
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
510
+ ) -> np.array:
511
+ """
512
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
513
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices corresponds to elements of similar
514
+ lengths. To do this, the indices are:
515
+
516
+ - randomly permuted (if a JAX rng is specified)
517
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
518
+ - sorted by length in each mega-batch
519
+
520
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
521
+ maximum length placed first, so that an OOM happens sooner rather than later.
522
+ """
523
+ lengths = dataset["input_length"]
524
+
525
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
526
+ if mega_batch_mult is None:
527
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
528
+ # Just in case, for tiny datasets
529
+ if mega_batch_mult == 0:
530
+ mega_batch_mult = 1
531
+
532
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
533
+ num_samples = len(lengths)
534
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
535
+
536
+ megabatch_size = mega_batch_mult * batch_size
537
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
538
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
539
+
540
+ # The rest is to get the biggest batch first.
541
+ # Since each megabatch is sorted by descending length, the longest element is the first
542
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
543
+ max_idx = np.argmax(megabatch_maximums).item()
544
+ # Switch to put the longest batch in first position
545
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
546
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
547
+
548
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
549
+
550
+ return megabatches
551
+
552
+
553
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
554
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
555
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
556
+ num_samples = len(samples_idx)
557
+ if drop_last:
558
+ samples_to_remove = num_samples % batch_size
559
+ if samples_to_remove != 0:
560
+ samples_idx = samples_idx[:-samples_to_remove]
561
+ sections_split = num_samples // batch_size
562
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
563
+ else:
564
+ sections_split = math.ceil(num_samples / batch_size)
565
+ samples_idx = np.array_split(samples_idx, sections_split)
566
+ return samples_idx
567
+
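A small, self-contained sketch (toy lengths, made-up numbers) of how these two helpers combine: indices are grouped so each batch holds examples of similar length, and the batch containing the longest example ends up first:

import jax

toy_dataset = {"input_length": [12, 3, 7, 25, 5, 18, 9, 31]}   # stand-in for the HF dataset
grouped = get_grouped_indices(toy_dataset, batch_size=2, rng=jax.random.PRNGKey(0))
batches = generate_batch_splits(grouped, batch_size=2, drop_last=True)
# `batches` has shape (4, 2); row 0 starts with the index of the longest (length-31) example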
568
+
569
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
570
+ summary_writer.scalar("train_time", train_time, step)
571
+
572
+ train_metrics = get_metrics(train_metrics)
573
+ for key, vals in train_metrics.items():
574
+ tag = f"train_{key}"
575
+ for i, val in enumerate(vals):
576
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
577
+
578
+
579
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
580
+ for metric_name, value in eval_metrics.items():
581
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
582
+
583
+ if pred_str is not None:
584
+ # write output actual predictions for debugging
585
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
586
+
587
+
588
+ def write_wandb_log(metrics, step, prefix=None):
589
+ if jax.process_index() == 0:
590
+ log_metrics = {}
591
+ for k, v in metrics.items():
592
+ if "layer" in k:
593
+ log_metrics[f"{k}/"] = v
594
+ elif prefix is not None:
595
+ log_metrics[f"{prefix}/{k}"] = v
596
+ else:
597
+ log_metrics[k] = v
598
+ wandb.log(log_metrics, step)
599
+
600
+
601
+ def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"):
602
+ if jax.process_index() == 0:
603
+ # convert str data to a wandb compatible format
604
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
605
+ # we'll log the first 50 predictions for each epoch
606
+ wandb.log(
607
+ {
608
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
609
+ columns=["label_str", "pred_str"], data=str_data[:num_log]
610
+ )
611
+ },
612
+ step,
613
+ )
614
+
615
+
616
+ def create_learning_rate_fn(
617
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
618
+ ) -> Callable[[int], jnp.array]:
619
+ """Returns a linear warmup, linear_decay learning rate function."""
620
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
621
+ decay_fn = optax.linear_schedule(
622
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
623
+ )
624
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
625
+ return schedule_fn
626
+
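A quick sanity sketch of the schedule shape with made-up numbers (100 warmup steps, 1,000 total steps, peak learning rate 1e-4):

schedule = create_learning_rate_fn(num_train_steps=1000, num_warmup_steps=100, learning_rate=1e-4)
schedule(0)     # 0.0    (start of warmup)
schedule(50)    # 5e-05  (halfway through warmup)
schedule(100)   # 1e-04  (peak, end of warmup)
schedule(1000)  # 0.0    (fully decayed)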
627
+
628
+ def ctc_loss(
629
+ logits,
630
+ logits_attention_mask,
631
+ labels,
632
+ blank_id,
633
+ loss_reduction="mean",
634
+ output_emission_dict=False,
635
+ log_epsilon=-100000.0,
636
+ ):
637
+ """Computes CTC loss.
638
+ This function performs forward computation over an FSA with `N * 2` states
639
+ where `N` is the max number of labels. The states are split into two groups:
640
+ Phi states and emission states. a phi-state accepts repetition of
641
+ phi (blank)-symbols and transits to emission state when the correct label is
642
+ observed. An emission state accepts repetition of the label and transits to
643
+ the next phi states at any time (so called epsilon-transition).
644
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
645
+ and `N` denotes the time steps in `labels`.
646
+ Args:
647
+ logits: (B, T, K)-array containing log-probabilities of each class.
648
+ logitpaddings: (B, T)-array. Padding indicators for `logits`.
649
+ labels: (B, N)-array containing reference integer labels.
650
+ labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently,
651
+ `labels` must be right-padded, i.e. each row of `labelpaddings` must be
652
+ repetition of zeroes, followed by repetition of ones.
653
+ blank_id: Id for blank token.
654
+ loss_reduction: one of "none", "mean", "sum"
655
+ - "none": no reduction is applied.
656
+ - "mean": output loss will be divided by target lengths and then the
657
+ mean over the batch is taken.
658
+ - "sum": output loss are summed over batch
659
+ output_emission_dict: whether to output additional information about the emission probs
660
+ Returns:
661
+ A pair of `(per_seq_loss, aux)`.
662
+ per_seq_loss:
663
+ (B,)-array containing loss values for each sequence in the batch.
664
+ aux: Dictionary containing interim variables used for computing losses.
665
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
666
+ phi-state corresponding to the n-th label.
667
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
668
+ emission-state corresponding to the n-th label.
669
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
670
+ corresponding to each time frame.
671
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
672
+ corresponding to each time frame.
673
+ """
674
+ # label paddings are indicated by -100
675
+ labelpaddings = labels < 0
676
+ # logit paddings are the inverse of attention_mask
677
+ logitpaddings = ~logits_attention_mask
678
+
679
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
680
+ batchsize, unused_maxinputlen, num_classes = logits.shape
681
+ batchsize_, maxlabellen = labels.shape
682
+
683
+ logprobs = jax.nn.log_softmax(logits)
684
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
685
+
686
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
687
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
688
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
689
+
690
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
691
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
692
+
693
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
694
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
695
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
696
+
697
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
698
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
699
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
700
+
701
+ def loop_body(prev, x):
702
+ prev_phi, prev_emit = prev
703
+ # emit-to-phi epsilon transition, except if the next label is repetition
704
+ prev_phi_orig = prev_phi
705
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
706
+
707
+ logprob_emit, logprob_phi, pad = x
708
+
709
+ # phi-to-emit transition
710
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
711
+ # self-loop transition
712
+ next_phi = prev_phi + logprob_phi
713
+ # emit-to-phi blank transition only when the next label is repetition
714
+ next_phi = next_phi.at[:, 1:].set(
715
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
716
+ )
717
+
718
+ pad = pad.reshape((batchsize, 1))
719
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
720
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
721
+
722
+ return (next_phi, next_emit), (next_phi, next_emit)
723
+
724
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
725
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
726
+
727
+ # last row needs to be updated with the last epsilon transition
728
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
729
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
730
+
731
+ # extract per_seq_loss
732
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
733
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
734
+
735
+ if loss_reduction == "mean":
736
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
737
+ loss = (per_seq_loss / target_lengths).mean()
738
+ elif loss_reduction == "sum":
739
+ loss = per_seq_loss.sum()
740
+ else:
741
+ loss = per_seq_loss
742
+
743
+ if not output_emission_dict:
744
+ return loss
745
+
746
+ return loss, {
747
+ "logalpha_phi": logalpha_phi,
748
+ "logalpha_emit": logalpha_emit,
749
+ "logprobs_phi": logprobs_phi,
750
+ "logprobs_emit": logprobs_emit,
751
+ }
752
+
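A shape-level sketch of calling the loss on dummy tensors (all values here are made up; the blank id and the -100 label-padding convention follow the code above):

import jax.numpy as jnp

logits = jnp.zeros((2, 50, 32))                              # (batch, frames, vocab)
attention_mask = jnp.ones((2, 50), dtype=jnp.bool_)          # no padded frames
labels = jnp.array([[5, 6, 7, -100], [8, 9, -100, -100]])    # -100 marks label padding
loss = ctc_loss(logits, attention_mask, labels, blank_id=0, loss_reduction="mean")  # scalar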
753
+
754
+ def make_dataset(data_args, seed=42):
755
+ # Pre-processing dataset
756
+ import re
757
+
758
+ def map_nst(entry):
759
+ text = entry["text"].lower()
760
+ text = text.replace("(...vær stille under dette opptaket...)", "")
761
+ text = re.sub('[áàâ]', 'a', text)
762
+ text = re.sub('[ä]', 'æ', text)
763
+ text = re.sub('[éèëê]', 'e', text)
764
+ text = re.sub('[íìïî]', 'i', text)
765
+ text = re.sub('[óòöô]', 'o', text)
766
+ text = re.sub('[ö]', 'ø', text)
767
+ text = re.sub('[ç]', 'c', text)
768
+ text = re.sub('[úùüû]', 'u', text)
769
+ # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
770
+ text = re.sub('\s+', ' ', text)
771
+ return {"text": text}
772
+
773
+ def filter_nst(entry):
774
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
775
+ return False # Too short
776
+ if re.match("pIW|CA", entry["type"]):  # re.match takes the pattern first
777
+ return False # Spelling out words
778
+ return True
779
+
780
+ def filter_npsc(entry):
781
+ # False if there are digits in the text
782
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
783
+ return False # Too short
784
+ if re.search("\d", entry["text"]):
785
+ return False
786
+ return True
787
+
788
+ def map_npsc(entry):
789
+ batch = {"text": entry["text"].lower()}
790
+ batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
791
+ batch["text"] = re.sub('[ä]', 'æ', batch["text"])
792
+ batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
793
+ batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
794
+ batch["text"] = re.sub('[óòöô]', 'o', batch["text"])
795
+ batch["text"] = re.sub('[ö]', 'ø', batch["text"])
796
+ batch["text"] = re.sub('[ç]', 'c', batch["text"])
797
+ batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
798
+ batch["text"] = re.sub('\s', ' ', batch["text"])
799
+ batch["text"] = re.sub('<ee>', 'eee', batch["text"])
800
+ batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
801
+ batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
802
+ batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
803
+ # batch["text"] = re.sub('<inaudible>', '?', batch["text"])
804
+ if "<" in batch["text"]:
805
+ raise ValueError(batch["text"])
806
+ return batch
807
+
808
+ nst = datasets.load_dataset("NbAiLab/NST", "no-close")
809
+ npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
810
+ # TODO NST_hesitate
811
+
812
+ split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC
813
+ nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed)
814
+ nst[data_args.train_split_name] = nst_train["train"]
815
+ nst[data_args.eval_split_name] = nst_train["test"]
816
+
817
+ nst = nst.filter(filter_nst).map(
818
+ map_nst,
819
+ num_proc=data_args.preprocessing_num_workers,
820
+ desc="filtering NST",
821
+ ).shuffle(seed=seed)
822
+ npsc = npsc.filter(filter_npsc).map(
823
+ map_npsc,
824
+ num_proc=data_args.preprocessing_num_workers,
825
+ desc="filtering NPSC",
826
+ ).shuffle(seed=seed)
827
+
828
+ npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]])
829
+ nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]])
830
+
831
+ combined = {}
832
+ for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name:
833
+ probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples
834
+ probs = (probs / probs.sum()).tolist()
835
+ comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
836
+ combined[split] = comb
837
+
838
+ return datasets.DatasetDict(**combined)
839
+
840
+ def main():
841
+ # 1. Parse input arguments
842
+ # See all possible arguments in src/transformers/training_args.py
843
+ # or by passing the --help flag to this script.
844
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
845
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
846
+
847
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
848
+ # If we pass only one argument to the script and it's the path to a json file,
849
+ # let's parse it to get our arguments.
850
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
851
+ else:
852
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
853
+
854
+ # 2. Setup logging
855
+ # Make one log on every process with the configuration for debugging.
856
+ logging.basicConfig(
857
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
858
+ datefmt="%m/%d/%Y %H:%M:%S",
859
+ handlers=[logging.StreamHandler(sys.stdout)],
860
+ )
861
+ # Set the verbosity to info of the Transformers logger.
862
+ # We only want one process per machine to log things on the screen.
863
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
864
+ if jax.process_index() == 0:
865
+ datasets.utils.logging.set_verbosity_warning()
866
+ transformers.utils.logging.set_verbosity_info()
867
+ else:
868
+ datasets.utils.logging.set_verbosity_error()
869
+ transformers.utils.logging.set_verbosity_error()
870
+
871
+ # Set up wandb run
872
+ if jax.process_index() == 0:
873
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
874
+
875
+ logger.info("Training/evaluation parameters %s", training_args)
876
+
877
+ # Set the default TPU matmul precision and display the number of devices
878
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
879
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
880
+
881
+ # 4. Load dataset
882
+
883
+ set_seed(training_args.seed)
884
+ raw_datasets = make_dataset(data_args, seed=training_args.seed)
885
+
886
+ # raw_datasets = DatasetDict()
887
+
888
+ # if training_args.do_train:
889
+ # raw_datasets[data_args.train_split_name] = load_dataset(
890
+ # data_args.dataset_name,
891
+ # data_args.dataset_config_name,
892
+ # split=data_args.train_split_name,
893
+ # cache_dir=data_args.dataset_cache_dir,
894
+ # use_auth_token=True if model_args.use_auth_token else None,
895
+ # )
896
+
897
+ # if training_args.do_eval:
898
+ # raw_datasets[data_args.eval_split_name] = load_dataset(
899
+ # data_args.dataset_name,
900
+ # data_args.dataset_config_name,
901
+ # split=data_args.eval_split_name,
902
+ # cache_dir=data_args.dataset_cache_dir,
903
+ # use_auth_token=True if model_args.use_auth_token else None,
904
+ # )
905
+
906
+ # if training_args.do_predict:
907
+ # test_split = data_args.test_split_name.split("+")
908
+ # for split in test_split:
909
+ # raw_datasets[split] = load_dataset(
910
+ # data_args.dataset_name,
911
+ # data_args.dataset_config_name,
912
+ # split=split,
913
+ # cache_dir=data_args.dataset_cache_dir,
914
+ # use_auth_token=True if model_args.use_auth_token else None,
915
+ # )
916
+
917
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
918
+ raise ValueError(
919
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
920
+ "training, evaluation or prediction has to be done."
921
+ )
922
+
923
+ # if not training, there is no need to run multiple epochs
924
+ if not training_args.do_train:
925
+ training_args.num_train_epochs = 1
926
+
927
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
928
+ raise ValueError(
929
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
930
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
931
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
932
+ )
933
+
934
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
935
+ raise ValueError(
936
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
937
+ "Make sure to set `--text_column_name` to the correct text column - one of "
938
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
939
+ )
940
+
941
+ # 5. Load pretrained model, tokenizer, and feature extractor
942
+ #
943
+ # Distributed training:
944
+ # The .from_pretrained methods guarantee that only one local process can concurrently
945
+ config = Wav2Vec2Config.from_pretrained(
946
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
947
+ cache_dir=model_args.cache_dir,
948
+ revision=model_args.model_revision,
949
+ use_auth_token=True if model_args.use_auth_token else None,
950
+ )
951
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
952
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
953
+ cache_dir=model_args.cache_dir,
954
+ revision=model_args.model_revision,
955
+ use_auth_token=True if model_args.use_auth_token else None,
956
+ )
957
+ tokenizer = AutoTokenizer.from_pretrained(
958
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
959
+ cache_dir=model_args.cache_dir,
960
+ revision=model_args.model_revision,
961
+ use_auth_token=True if model_args.use_auth_token else None,
962
+ )
963
+ # update config according to training args, model args, and tokenizer attributes
964
+ config.update(
965
+ {
966
+ "feat_proj_dropout": model_args.feat_proj_dropout,
967
+ "attention_dropout": model_args.attention_dropout,
968
+ "hidden_dropout": model_args.hidden_dropout,
969
+ "final_dropout": model_args.final_dropout,
970
+ "mask_time_prob": model_args.mask_time_prob,
971
+ "mask_time_length": model_args.mask_time_length,
972
+ "mask_feature_prob": model_args.mask_feature_prob,
973
+ "mask_feature_length": model_args.mask_feature_length,
974
+ "gradient_checkpointing": training_args.gradient_checkpointing,
975
+ "layerdrop": model_args.layerdrop,
976
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
977
+ "ctc_zero_infinity": model_args.ctc_zero_infinity,
978
+ "pad_token_id": tokenizer.pad_token_id,
979
+ "vocab_size": tokenizer.vocab_size, # len(tokenizer),
980
+ "activation_dropout": model_args.activation_dropout,
981
+ }
982
+ )
983
+
984
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
985
+ raise ValueError(
986
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
987
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus,"
988
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely "
989
+ "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`."
990
+ )
991
+
992
+ if training_args.precision == "full_mixed":
993
+ dtype = jnp.bfloat16
994
+ training_args.mixed_precision = True
995
+ elif training_args.precision == "half_mixed":
996
+ dtype = jnp.bfloat16
997
+ training_args.mixed_precision = False
998
+ else:
999
+ dtype = jnp.float32
1000
+ training_args.mixed_precision = False
1001
+
1002
+ try:
1003
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1004
+ model_args.model_name_or_path,
1005
+ config=config,
1006
+ dtype=dtype,
1007
+ cache_dir=model_args.cache_dir,
1008
+ revision=model_args.model_revision,
1009
+ use_auth_token=True if model_args.use_auth_token else None,
1010
+ )
1011
+ except Exception:  # fall back to loading from PyTorch weights below
1012
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1013
+ model_args.model_name_or_path,
1014
+ config=config,
1015
+ dtype=dtype,
1016
+ cache_dir=model_args.cache_dir,
1017
+ revision=model_args.model_revision,
1018
+ use_auth_token=True if model_args.use_auth_token else None,
1019
+ from_pt=True,
1020
+ )
1021
+
1022
+ # 6. Resample speech dataset ALWAYS
1023
+ raw_datasets = raw_datasets.cast_column(
1024
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
1025
+ )
1026
+
1027
+ # 7. Preprocessing the datasets.
1028
+ # We need to read the audio files as arrays and tokenize the targets.
1029
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
1030
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
1031
+ max_target_length = data_args.max_label_length
1032
+ min_target_length = data_args.min_label_length
1033
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
1034
+ audio_column_name = data_args.audio_column_name
1035
+ num_workers = data_args.preprocessing_num_workers
1036
+ text_column_name = data_args.text_column_name
1037
+ model_input_name = feature_extractor.model_input_names[0]
1038
+ do_lower_case = data_args.do_lower_case
1039
+ dataset_name = data_args.dataset_name
1040
+ chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
1041
+ chars_to_ignore_regex = f'[{re.escape("".join(chars_to_ignore))}]'  # escape so "-" does not form a character range
1042
+ # gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
1043
+ # gigaspeech_disfluencies = ["<other>", "<sil>"]
1044
+ # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
1045
+ # "[vocalized-noise]", "_1"]
1046
+ # swb_punctuations = ["{", "}", "[", "]-", "]"]
1047
+ # earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>", "<unk>"]
1048
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
1049
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
1050
+
1051
+ if training_args.do_train and data_args.max_train_samples is not None:
1052
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples))
1053
+
1054
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1055
+ raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples))
1056
+
1057
+ if training_args.do_predict and data_args.max_test_samples is not None:
1058
+ raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_test_samples))
1059
+
1060
+ if training_args.do_train and data_args.remove_punctuation:
1061
+
1062
+ def remove_punctuation(batch):
1063
+ batch[text_column_name] = (
1064
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
1065
+ )
+ return batch  # needed so datasets.map actually applies the change
1066
+
1067
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map(
1068
+ remove_punctuation,
1069
+ num_proc=data_args.preprocessing_num_workers,
1070
+ desc="removing punctuation from train split",
1071
+ )
1072
+
1073
+ # filter data where the targets are ignored in scoring
1074
+ def is_target_labels(input_str):
1075
+ return input_str.lower() not in ignore_segments
1076
+
1077
+ raw_datasets = raw_datasets.filter(
1078
+ is_target_labels,
1079
+ num_proc=num_workers,
1080
+ input_columns=[text_column_name],
1081
+ desc="filtering data where the targets are ignored in scoring",
1082
+ )
1083
+
1084
+ def prepare_dataset(batch):
1085
+ # process audio
1086
+ try:
1087
+ sample = batch[audio_column_name]
1088
+ except ValueError:
1089
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
1090
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1091
+ # process audio length
1092
+ batch[model_input_name] = inputs.input_values[0]
1093
+ batch["input_length"] = len(batch["input_values"])
1094
+
1095
+ # process targets
1096
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
1097
+
1098
+ # if dataset_name == "google/xtreme_s":
1099
+ # # Finally, we tokenize the processed text
1100
+ # batch["labels"] = tokenizer(input_str).input_ids
1101
+ # batch["labels_length"] = len(batch["labels"])
1102
+ # return batch
1103
+
1104
+ # # Common Voice 9
1105
+ # if input_str.startswith('"') and input_str.endswith('"'):
1106
+ # # we can remove trailing quotation marks as they do not affect the transcription
1107
+ # input_str = input_str[1:-1]
1108
+ # # normalize quotation marks
1109
+ # input_str = re.sub(r'["“”]', '"', input_str)
1110
+ # # normalize apostrophes
1111
+ # input_str = re.sub(r"[’']", "'", input_str)
1112
+ # # normalize hyphens
1113
+ # input_str = re.sub(r"[—–]", "-", input_str)
1114
+ # # replace double quotation marks with single
1115
+ # input_str = input_str.replace('""', '"')
1116
+ # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str):
1117
+ # # for CV9, we'll normalize the text to always finish with punctuation
1118
+ # if input_str[-1] not in [".", "?", "!"]:
1119
+ # input_str = input_str + "."
1120
+
1121
+ # # TEDLIUM-3
1122
+ # # delete the <unk> token from the text and replace spaced apostrophes with un-spaced
1123
+ # input_str = input_str.replace("<unk>", "").replace(" '", "'")
1124
+
1125
+ # # GigaSpeech
1126
+ # for disfluency in gigaspeech_disfluencies:
1127
+ # input_str = input_str.replace(disfluency, "")
1128
+ # # convert spelled out punctuation to symbolic form
1129
+ # for punctuation, replacement in gigaspeech_punctuation.items():
1130
+ # input_str = input_str.replace(punctuation, replacement)
1131
+ # if dataset_name == "speechcolab/gigaspeech" and len(input_str):
1132
+ # # for GS, we'll normalize the text to always finish with punctuation
1133
+ # if input_str[-1] not in [".", "?", "!"]:
1134
+ # input_str = input_str + "."
1135
+
1136
+ # # SWB
1137
+ # for disfluency in swb_disfluencies:
1138
+ # input_str = input_str.replace(disfluency, "")
1139
+ # # remove parenthesised text (test data only)
1140
+ # input_str = re.sub("[\(].*?[\)]", "", input_str)
1141
+ # for punctuation in swb_punctuations:
1142
+ # input_str = input_str.replace(punctuation, "")
1143
+ # # replace anomalous words with their correct transcriptions
1144
+ # split_str = input_str.split("/")
1145
+ # if len(split_str) > 1:
1146
+ # input_str = " ".join(
1147
+ # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1148
+
1149
+ # # Earnings 22
1150
+ # for disfluency in earnings_disfluencies:
1151
+ # input_str = input_str.replace(disfluency, "")
1152
+ # # replace mal-formatted ellipsis
1153
+ # input_str = input_str.replace("…", ".")
1154
+
1155
+ # JIWER compliance
1156
+ # remove multiple spaces
1157
+ input_str = re.sub(r"\s\s+", " ", input_str)
1158
+ # strip trailing spaces
1159
+ input_str = input_str.strip()
1160
+
1161
+ # Finally, we tokenize the processed text
1162
+ batch["labels"] = tokenizer(input_str).input_ids
1163
+ batch["labels_length"] = len(batch["labels"])
1164
+ return batch
1165
+
1166
+ vectorized_datasets = raw_datasets.map(
1167
+ prepare_dataset,
1168
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1169
+ num_proc=num_workers,
1170
+ desc="preprocess dataset",
1171
+ )
1172
+
1173
+ # filter data with inputs shorter than min_input_length or longer than max_input_length
1174
+ def is_audio_in_length_range(length):
1175
+ return length > min_input_length and length < max_input_length
1176
+
1177
+ vectorized_datasets = vectorized_datasets.filter(
1178
+ is_audio_in_length_range,
1179
+ num_proc=num_workers,
1180
+ input_columns=["input_length"],
1181
+ )
1182
+
1183
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1184
+ def is_labels_in_length_range(length):
1185
+ return length > min_target_length # and length < max_target_length
1186
+
1187
+ vectorized_datasets = vectorized_datasets.filter(
1188
+ is_labels_in_length_range,
1189
+ num_proc=num_workers,
1190
+ input_columns=["labels_length"],
1191
+ )
1192
+
1193
+ # for large datasets it is advised to run the preprocessing on a
1194
+ # single machine first with `args.preprocessing_only` since there will mostly likely
1195
+ # be a timeout when running the script in distributed mode.
1196
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1197
+ # cached dataset
1198
+ if data_args.preprocessing_only:
1199
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1200
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1201
+ return
1202
+
1203
+ # 8. Load Metrics
1204
+ wer_metric = load_metric("wer")
1205
+ cer_metric = load_metric("cer")
1206
+
1207
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1208
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1209
+
1210
+ pred_str = tokenizer.batch_decode(pred_ids)
1211
+ # we do not want to group tokens when computing the metrics
1212
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
1213
+
1214
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
1215
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
1216
+
1217
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1218
+
1219
+ # 9. save feature extractor, tokenizer and config
1220
+ feature_extractor.save_pretrained(training_args.output_dir)
1221
+ tokenizer.save_pretrained(training_args.output_dir)
1222
+ config.save_pretrained(training_args.output_dir)
1223
+
1224
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1225
+
1226
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1227
+ processor=processor,
1228
+ input_padding="longest",
1229
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1230
+ max_label_length=data_args.max_label_length,
1231
+ )
1232
+
1233
+ # Enable tensorboard only on the master node
1234
+ has_tensorboard = is_tensorboard_available()
1235
+ if has_tensorboard and jax.process_index() == 0:
1236
+ try:
1237
+ from flax.metrics.tensorboard import SummaryWriter
1238
+
1239
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1240
+ except ImportError as ie:
1241
+ has_tensorboard = False
1242
+ logger.warning(
1243
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1244
+ )
1245
+ else:
1246
+ logger.warning(
1247
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1248
+ "Please run `pip install tensorboard` to enable."
1249
+ )
1250
+
1251
+ # 10. Handle the repository creation
1252
+ if training_args.push_to_hub:
1253
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1254
+ git_lfs_extensions = f.read()
1255
+ if "*.wandb" not in git_lfs_extensions:
1256
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1257
+ if training_args.hub_model_id is None:
1258
+ repo_name = get_full_repo_name(
1259
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1260
+ )
1261
+ else:
1262
+ repo_name = training_args.hub_model_id
1263
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1264
+
1265
+ # 11. Initialize our training
1266
+ rng = jax.random.PRNGKey(training_args.seed)
1267
+ rng, dropout_rng = jax.random.split(rng)
1268
+
1269
+ # Store some constants
1270
+ max_steps = int(training_args.max_steps)
1271
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1272
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1273
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1274
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1275
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1276
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1277
+
1278
+ if training_args.do_train:
1279
+ num_train_samples = len(vectorized_datasets[data_args.train_split_name])
1280
+ steps_per_epoch = num_train_samples // batch_size_per_update
1281
+ if max_steps > 0:
1282
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1283
+ total_train_steps = max_steps
1284
+ else:
1285
+ num_epochs = int(training_args.num_train_epochs)
1286
+ total_train_steps = steps_per_epoch * num_epochs
1287
+
1288
+ # Create learning rate schedule
1289
1290
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1291
+ total_train_steps,
1292
+ training_args.warmup_steps,
1293
+ training_args.learning_rate,
1294
+ )
1295
+
1296
+ # We use Optax's "masking" functionality to not apply weight decay
1297
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1298
+ # mask boolean with the same structure as the parameters.
1299
+ # The mask is True for parameters that should be decayed.
1300
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1301
+ # For FlaxT5, one should correct the layer norm parameter naming
1302
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1303
+ def decay_mask_fn(params):
1304
+ flat_params = traverse_util.flatten_dict(params)
1305
+ layer_norm_params = [
1306
+ (name, "scale")
1307
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1308
+ ]
1309
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1310
+ return traverse_util.unflatten_dict(flat_mask)
1311
+
1312
+ if training_args.adafactor:
1313
+ # Create Adafactor optimizer
1314
+ optim = optax.adafactor(
1315
+ learning_rate=linear_decay_lr_schedule_fn,
1316
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1317
+ weight_decay_rate=training_args.weight_decay,
1318
+ weight_decay_mask=decay_mask_fn,
1319
+ )
1320
+ else:
1321
+ # Create AdamW optimizer
1322
+ optim = optax.adamw(
1323
+ learning_rate=linear_decay_lr_schedule_fn,
1324
+ b1=training_args.adam_beta1,
1325
+ b2=training_args.adam_beta2,
1326
+ eps=training_args.adam_epsilon,
1327
+ weight_decay=training_args.weight_decay,
1328
+ mask=decay_mask_fn,
1329
+ )
1330
+
1331
+ # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
1332
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1333
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
1334
+ else:
1335
+ num_epochs = 0
1336
+ total_train_steps = 0
1337
+ num_train_samples = 0
1338
+ optim = None
1339
+
1340
+ # Setup train state
1341
+ state = MixedPrecisionTrainState.create(
1342
+ apply_fn=model.__call__,
1343
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1344
+ params=model.params,
1345
+ tx=optim,
1346
+ to_dtype=to_dtype,
1347
+ dropout_rng=dropout_rng,
1348
+ max_grad_norm=training_args.max_grad_norm,
1349
+ )
1350
+
1351
+ # Replicate the train state on each device
1352
+ state = state.replicate()
1353
+ blank_id = model.config.pad_token_id
1354
+
1355
+ # Define gradient update step fn
1356
+ def train_step(state, batch):
1357
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1358
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1359
+
1360
+ def compute_loss(params, minibatch):
1361
+ labels = minibatch.pop("labels")
1362
+ logits = state.apply_fn(
1363
+ **minibatch,
1364
+ params=params,
1365
+ dropout_rng=dropout_rng,
1366
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1367
+ train=True,
1368
+ )[0]
1369
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"])
1370
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1371
+
1372
+ return loss
1373
+
1374
+ grad_fn = jax.value_and_grad(compute_loss)
1375
+
1376
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1377
+ loss, grad = grad_fn(to_dtype(state.params), batch)
1378
+
1379
+ # Custom gradient accumulation
1380
+ else:
1381
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1382
+ batch = jax.tree_util.tree_map(
1383
+ lambda x: x.reshape(
1384
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1385
+ ),
1386
+ batch,
1387
+ )
1388
+
1389
+ def accum_minibatch_step(accum_grad, minibatch):
1390
+ # compute loss, num labels and grad over minibatch and accumulate
1391
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
1392
+ return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss
1393
+
1394
+ # create an initial state for accumulating losses, num labels and gradients
1395
+ init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params))
1396
+ # loop accum minibatch step over the number of gradient accumulation steps
1397
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
1398
+
1399
+ # update state
1400
+ new_state = state.apply_gradients(
1401
+ grads=grad,
1402
+ dropout_rng=new_dropout_rng,
1403
+ to_dtype=to_dtype,
1404
+ )
1405
+
1406
+ # compute gradient norms over all layers and globally for detailed monitoring
1407
+ layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad)
1408
+ logs = {
1409
+ "layer_grad_norm": layer_grad_norm,
1410
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
1411
+ }
1412
+
1413
+ # compute parameter norms over all layers and globally for detailed monitoring
1414
+ layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params)
1415
+ logs["layer_param_norm"] = layer_param_norm
1416
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
1417
+
1418
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1419
+ metrics.update(logs)
1420
+
1421
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1422
+ # metrics = to_fp32(metrics)
1423
+
1424
+ return new_state, metrics
1425
+
1426
+ # Define eval fn
1427
+ def eval_step(params, batch):
1428
+ labels = batch.pop("labels")
1429
+ logits = model(**batch, params=params, train=False)[0]
1430
+
1431
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
1432
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1433
+
1434
+ pred_ids = jnp.argmax(logits, axis=-1)
1435
+
1436
+ # summarize metrics
1437
+ metrics = {"loss": loss}
1438
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1439
+ # metrics = to_fp32(metrics)
1440
+ return metrics, pred_ids
1441
+
1442
+ # Create parallel version of the train and eval step
1443
+ if training_args.do_train:
1444
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1445
+
1446
+ if training_args.do_eval:
1447
+ p_eval_step = jax.pmap(eval_step, "batch")
1448
+
1449
+ def run_evaluation(step):
1450
+ if training_args.do_eval:
1451
+ # ======================== Evaluating ==============================
1452
+ eval_metrics = []
1453
+ eval_preds = []
1454
+ eval_labels = []
1455
+
1456
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1457
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size)
1458
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1459
+
1460
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1461
+ samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx]
1462
+ batch = data_collator(samples)
1463
+ labels = batch["labels"]
1464
+
1465
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1466
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1467
+ eval_metrics.append(metrics)
1468
+
1469
+ eval_labels.extend(labels)
1470
+
1471
+ # normalize eval metrics
1472
+ eval_metrics = get_metrics(eval_metrics)
1473
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1474
+ eval_metrics = to_fp32(eval_metrics)
1475
+
1476
+ # always run compute metrics
1477
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1478
+ eval_metrics.update(error_rate_metric)
1479
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1480
+
1481
+ # Print metrics and update progress bar
1482
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1483
+ epochs.write(desc)
1484
+ epochs.desc = desc
1485
+
1486
+ # Save metrics
1487
+ write_wandb_log(eval_metrics, step, prefix="eval")
1488
+ write_wandb_pred(pred_str, label_str, step)
1489
+ # if has_tensorboard and jax.process_index() == 0:
1490
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1491
+
1492
+ def save_checkpoint(step):
1493
+ # save and push checkpoint to the hub
1494
+ if jax.process_index() == 0:
1495
+ params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
1496
+ model.save_pretrained(training_args.output_dir, params=params)
1497
+ tokenizer.save_pretrained(training_args.output_dir)
1498
+ if training_args.push_to_hub:
1499
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1500
+
1501
+ skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update)
1502
+ logger.info("***** Running training *****")
1503
+ logger.info(f" Num examples = {num_train_samples}")
1504
+ logger.info(f" Num Epochs = {num_epochs}")
1505
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1506
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1507
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1508
+ logger.info(f" Total optimization steps = {total_train_steps}")
1509
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
1510
+ logger.info(f" Use scan: {config.use_scan}")
1511
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
1512
+ logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)")
1513
+
1514
+ train_time = cur_step = 0
1515
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1516
+ for epoch in epochs:
1517
+ if training_args.do_train:
1518
+ # ======================== Training ================================
1519
+ train_start = time.time()
1520
+
1521
+ if epoch < skip_epochs:
1522
+ logger.info(f"Skipping epoch {epoch + 1}")
1523
+ continue
1524
+
1525
+ # Create sampling rng
1526
+ rng, input_rng = jax.random.split(rng)
1527
+
1528
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1529
+ train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
1530
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1531
+
1532
+ if data_args.skip_steps > cur_step:
1533
+ logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...")
1534
+ # Gather the indices for creating the batch and do a training step
1535
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1536
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1537
+ if cur_step <= data_args.skip_steps:
1538
+ continue
1539
+
1540
+ samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx]
1541
+ batch = data_collator(samples)
1542
+ batch = shard(batch.data)
1543
+ try:
1544
+ state, train_metric = p_train_step(state, batch)
1545
+ except TypeError as e:
1546
+ logger.warning("Encountered following error: \n", e)
1547
+
1548
+
1549
+ if cur_step % training_args.logging_steps == 0:
1550
+ # Save metrics
1551
+ train_metric = unreplicate(train_metric)
1552
+ train_time += time.time() - train_start
1553
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1554
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name)
1555
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1556
+ # if has_tensorboard and jax.process_index() == 0:
1557
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1558
+
1559
+ epochs.write(
1560
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1561
+ )
1562
+
1563
+ if cur_step % total_train_steps == 0:
1564
+ break
1565
+
1566
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1567
+ run_evaluation(cur_step)
1568
+
1569
+ if cur_step % training_args.save_steps == 0:
1570
+ save_checkpoint(cur_step)
1571
+
1572
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1573
+ # run evaluation at the end of the epoch if eval steps are not specified
1574
+ run_evaluation(cur_step)
1575
+ save_checkpoint(cur_step)
1576
+
1577
+ if training_args.do_train:
1578
+ save_checkpoint(cur_step)
1579
+
1580
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1581
+
1582
+ if training_args.do_eval:
1583
+ run_evaluation(cur_step)
1584
+
1585
+ # TODO: collapse 'do_predict' into the run_evaluation function
1586
+ if training_args.do_predict:
1587
+ for split in [data_args.test_split_name]:
1588
+ # ======================== Evaluating ==============================
1589
+ eval_metrics = []
1590
+ eval_preds = []
1591
+ eval_labels = []
1592
+
1593
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
1594
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1595
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1596
+
1597
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
1598
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1599
+ batch = data_collator(samples)
1600
+ labels = batch["labels"]
1601
+
1602
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1603
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1604
+ eval_metrics.append(metrics)
1605
+
1606
+ eval_labels.extend(labels)
1607
+
1608
+ # normalize eval metrics
1609
+ eval_metrics = get_metrics(eval_metrics)
1610
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1611
+ eval_metrics = to_fp32(eval_metrics)
1612
+
1613
+ # always run compute_metrics
1614
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1615
+ eval_metrics.update(error_rate_metric)
1616
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1617
+
1618
+ # Print metrics and update progress bar
1619
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1620
+ epochs.write(desc)
1621
+ epochs.desc = desc
1622
+
1623
+ # Save metrics
1624
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
1625
+ write_wandb_pred(pred_str, label_str, cur_step, prefix=split)
1626
+ # if has_tensorboard and jax.process_index() == 0:
1627
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
1628
+
1629
+
1630
+ if __name__ == "__main__":
1631
+ main()
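The resume logic in the training loop above first skips whole epochs (skip_epochs) and then skips individual batches inside the first partial epoch until cur_step exceeds data_args.skip_steps. A minimal sketch of that arithmetic, using hypothetical placeholder numbers rather than values from this run:

    # Sketch of the skip_steps -> skip_epochs arithmetic used in the loop above.
    # All numbers are hypothetical placeholders, not values from this run.
    num_train_samples = 1_000_000
    batch_size_per_update = 16
    skip_steps = 275_000                                            # e.g. the value passed via --skip_steps

    steps_per_epoch = num_train_samples // batch_size_per_update    # 62_500
    skip_epochs = skip_steps // steps_per_epoch                     # 4 -> these epochs are skipped outright
    # Inside epoch number skip_epochs, batches are still enumerated, but the
    # train step is skipped while cur_step = epoch * steps_per_epoch + step <= skip_steps.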
wandb/run-20220810_111559-290849gb/files/config.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.9
7
+ code_path: code/run_flax_speech_recognition_ctc.py
8
+ framework: huggingface
9
+ huggingface_version: 4.21.0
10
+ is_jupyter_run: false
11
+ is_kaggle_kernel: false
12
+ python_version: 3.8.10
13
+ start_time: 1660130159
14
+ t:
15
+ 1:
16
+ - 1
17
+ - 2
18
+ - 3
19
+ - 11
20
+ - 12
21
+ 2:
22
+ - 1
23
+ - 2
24
+ - 3
25
+ - 11
26
+ - 12
27
+ 3:
28
+ - 13
29
+ 4: 3.8.10
30
+ 5: 0.12.9
31
+ 6: 4.21.0
32
+ 8:
33
+ - 5
wandb/run-20220810_111559-290849gb/files/diff.patch ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/special_tokens_map.json b/special_tokens_map.json
2
+ index 218961f..c11fc15 100644
3
+ --- a/special_tokens_map.json
4
+ +++ b/special_tokens_map.json
5
+ @@ -399,6 +399,20 @@
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ + {
10
+ + "content": "</s>",
11
+ + "lstrip": false,
12
+ + "normalized": true,
13
+ + "rstrip": false,
14
+ + "single_word": false
15
+ + },
16
+ + {
17
+ + "content": "<s>",
18
+ + "lstrip": false,
19
+ + "normalized": true,
20
+ + "rstrip": false,
21
+ + "single_word": false
22
+ + },
23
+ {
24
+ "content": "</s>",
25
+ "lstrip": false,
26
+ diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log
27
+ index 23926ef..ad68f93 120000
28
+ --- a/wandb/debug-internal.log
29
+ +++ b/wandb/debug-internal.log
30
+ @@ -1 +1 @@
31
+ -run-20220805_230151-2y71vcu4/logs/debug-internal.log
32
+
33
+ +run-20220810_111559-290849gb/logs/debug-internal.log
34
+
35
+ diff --git a/wandb/debug.log b/wandb/debug.log
36
+ index 279853d..8db277f 120000
37
+ --- a/wandb/debug.log
38
+ +++ b/wandb/debug.log
39
+ @@ -1 +1 @@
40
+ -run-20220805_230151-2y71vcu4/logs/debug.log
41
+
42
+ +run-20220810_111559-290849gb/logs/debug.log
43
+
44
+ diff --git a/wandb/latest-run b/wandb/latest-run
45
+ index f069a7a..052e8bb 120000
46
+ --- a/wandb/latest-run
47
+ +++ b/wandb/latest-run
48
+ @@ -1 +1 @@
49
+ -run-20220805_230151-2y71vcu4
50
+
51
+ +run-20220810_111559-290849gb
52
+
wandb/run-20220810_111559-290849gb/files/output.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5d175f5b339eb2b7e07c06a41cf262a197646be241ea83c5bc595ad9c114374
3
+ size 209075
wandb/run-20220810_111559-290849gb/files/requirements.txt ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ appdirs==1.4.4
5
+ astunparse==1.6.3
6
+ async-timeout==4.0.2
7
+ attrs==21.4.0
8
+ audioread==2.1.9
9
+ backcall==0.2.0
10
+ cachetools==4.2.4
11
+ certifi==2021.10.8
12
+ cffi==1.15.1
13
+ charset-normalizer==2.0.10
14
+ chex==0.1.3
15
+ click==8.0.3
16
+ cloud-tpu-client==0.10
17
+ cloud-tpu-profiler==2.4.0
18
+ clu==0.0.6
19
+ colorama==0.4.5
20
+ commonmark==0.9.1
21
+ configparser==5.2.0
22
+ contextlib2==21.6.0
23
+ cycler==0.11.0
24
+ datasets==2.4.0
25
+ decorator==5.1.0
26
+ dill==0.3.4
27
+ dm-tree==0.1.6
28
+ docker-pycreds==0.4.0
29
+ etils==0.6.0
30
+ exceptiongroup==1.0.0rc8
31
+ filelock==3.4.2
32
+ flatbuffers==2.0
33
+ flax==0.5.3
34
+ fonttools==4.28.5
35
+ frozenlist==1.2.0
36
+ fsspec==2021.11.1
37
+ future==0.18.2
38
+ gast==0.4.0
39
+ gitdb==4.0.9
40
+ gitpython==3.1.26
41
+ google-api-core==1.31.5
42
+ google-api-python-client==1.8.0
43
+ google-auth-httplib2==0.1.0
44
+ google-auth-oauthlib==0.4.6
45
+ google-auth==2.3.3
46
+ google-pasta==0.2.0
47
+ googleapis-common-protos==1.54.0
48
+ grpcio==1.43.0
49
+ h5py==3.6.0
50
+ httplib2==0.20.2
51
+ huggingface-hub==0.2.1
52
+ hypothesis==6.53.0
53
+ idna==3.3
54
+ importlib-metadata==4.10.0
55
+ importlib-resources==5.4.0
56
+ ipython==7.31.0
57
+ jax==0.3.15
58
+ jaxlib==0.3.15
59
+ jedi==0.18.1
60
+ jiwer==2.3.0
61
+ joblib==1.1.0
62
+ keras-preprocessing==1.1.2
63
+ keras==2.7.0
64
+ kiwisolver==1.3.2
65
+ libclang==12.0.0
66
+ librosa==0.9.2
67
+ libtpu-nightly==0.1.dev20220722
68
+ llvmlite==0.39.0
69
+ markdown==3.3.6
70
+ matplotlib-inline==0.1.3
71
+ matplotlib==3.5.1
72
+ ml-collections==0.1.0
73
+ msgpack==1.0.3
74
+ multidict==5.2.0
75
+ multiprocess==0.70.12.2
76
+ numba==0.56.0
77
+ numpy==1.22.0
78
+ oauth2client==4.1.3
79
+ oauthlib==3.1.1
80
+ opt-einsum==3.3.0
81
+ optax==0.1.3
82
+ packaging==21.3
83
+ pandas==1.3.5
84
+ parso==0.8.3
85
+ pathtools==0.1.2
86
+ pexpect==4.8.0
87
+ pickleshare==0.7.5
88
+ pillow==9.0.0
89
+ pip==22.2.2
90
+ pkg-resources==0.0.0
91
+ pooch==1.6.0
92
+ promise==2.3
93
+ prompt-toolkit==3.0.24
94
+ protobuf==3.19.1
95
+ psutil==5.9.0
96
+ ptyprocess==0.7.0
97
+ pyarrow==6.0.1
98
+ pyasn1-modules==0.2.8
99
+ pyasn1==0.4.8
100
+ pycparser==2.21
101
+ pyctcdecode==0.4.0
102
+ pygments==2.11.1
103
+ pygtrie==2.5.0
104
+ pyparsing==3.0.6
105
+ python-dateutil==2.8.2
106
+ python-levenshtein==0.12.2
107
+ pytz==2021.3
108
+ pyyaml==6.0
109
+ regex==2021.11.10
110
+ requests-oauthlib==1.3.0
111
+ requests==2.27.0
112
+ resampy==0.3.1
113
+ responses==0.18.0
114
+ rich==11.2.0
115
+ rsa==4.8
116
+ sacremoses==0.0.46
117
+ scikit-learn==1.1.1
118
+ scipy==1.7.3
119
+ sentry-sdk==1.5.2
120
+ setuptools==44.0.0
121
+ shortuuid==1.0.8
122
+ six==1.16.0
123
+ smmap==5.0.0
124
+ sortedcontainers==2.4.0
125
+ soundfile==0.10.3.post1
126
+ sox==1.4.1
127
+ subprocess32==3.5.4
128
+ tensorboard-data-server==0.6.1
129
+ tensorboard-plugin-wit==1.8.0
130
+ tensorboard==2.7.0
131
+ tensorflow-cpu==2.7.0
132
+ tensorflow-datasets==4.4.0
133
+ tensorflow-estimator==2.7.0
134
+ tensorflow-io-gcs-filesystem==0.23.1
135
+ tensorflow-metadata==1.5.0
136
+ tensorflow==2.7.0
137
+ tensorstore==0.1.21
138
+ termcolor==1.1.0
139
+ threadpoolctl==3.1.0
140
+ tokenizers==0.11.2
141
+ toolz==0.11.2
142
+ torch==1.12.0
143
+ torchaudio==0.12.0+cpu
144
+ tqdm==4.62.3
145
+ traitlets==5.1.1
146
+ transformers==4.21.0
147
+ typing-extensions==4.3.0
148
+ uritemplate==3.0.1
149
+ urllib3==1.26.7
150
+ wandb==0.12.9
151
+ wcwidth==0.2.5
152
+ werkzeug==2.0.2
153
+ wheel==0.37.1
154
+ wrapt==1.13.3
155
+ xxhash==2.0.2
156
+ yarl==1.7.2
157
+ yaspin==2.1.0
158
+ zipp==3.7.0
wandb/run-20220810_111559-290849gb/files/wandb-metadata.json ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2022-08-10T11:16:02.847385",
5
+ "startedAt": "2022-08-10T11:15:59.241818",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--model_name_or_path=./",
11
+ "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu",
12
+ "--tokenizer_name=./",
13
+ "--output_dir=./",
14
+ "--overwrite_output_dir",
15
+ "--num_train_epochs=40",
16
+ "--per_device_train_batch_size=2",
17
+ "--per_device_eval_batch_size=2",
18
+ "--gradient_accumulation_steps=1",
19
+ "--precision=full_mixed",
20
+ "--matmul_precision=bfloat16",
21
+ "--multisteps",
22
+ "--learning_rate=6.394633237505332e-05",
23
+ "--skip_steps=275000",
24
+ "--warmup_steps=2000",
25
+ "--length_column_name=input_length",
26
+ "--evaluation_strategy=steps",
27
+ "--text_column_name=text",
28
+ "--save_steps=5000",
29
+ "--eval_steps=5000",
30
+ "--logging_steps=100",
31
+ "--layerdrop=0.041",
32
+ "--attention_dropout=0.094",
33
+ "--activation_dropout=0.055",
34
+ "--hidden_dropout=0.047",
35
+ "--save_total_limit=5",
36
+ "--freeze_feature_encoder",
37
+ "--feat_proj_dropout=0.04",
38
+ "--mask_time_prob=0.082",
39
+ "--mask_time_length=10",
40
+ "--mask_feature_prob=0.25",
41
+ "--mask_feature_length=64",
42
+ "--gradient_checkpointing",
43
+ "--min_duration_in_seconds=0.5",
44
+ "--max_duration_in_seconds=30.0",
45
+ "--use_auth_token",
46
+ "--seed=42",
47
+ "--group_by_length",
48
+ "--do_train",
49
+ "--do_eval",
50
+ "--push_to_hub",
51
+ "--preprocessing_num_workers=32",
52
+ "--ctc_zero_infinity",
53
+ "--do_lower_case",
54
+ "--wandb_project=wav2vec2",
55
+ "--wandb_name=wav2vec2-1b-npsc-nst-tpu (cont.)",
56
+ "--remove_punctuation"
57
+ ],
58
+ "state": "running",
59
+ "program": "run_flax_speech_recognition_ctc.py",
60
+ "codePath": "run_flax_speech_recognition_ctc.py",
61
+ "git": {
62
+ "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu",
63
+ "commit": "f624ac4bfedfbf56891676d7c5f2e37b4c8e0745"
64
+ },
65
+ "email": "versae@gmail.com",
66
+ "root": "/data/wav2vec2-1b-npsc-nst-tpu",
67
+ "host": "t1v-n-eedfb410-w-0",
68
+ "username": "javierr",
69
+ "executable": "/data/flax/bin/python"
70
+ }
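The args list in wandb-metadata.json above records the run's hyperparameters, including per_device_train_batch_size=2 and gradient_accumulation_steps=1. The effective update batch size that the training script logs as "Total train batch size (w. parallel & distributed)" follows from these values and the number of devices; a hedged sketch of that calculation, assuming an 8-device TPU host (the device count itself is not recorded in this metadata):

    # Sketch only: the device count is an assumption, not a value from the metadata.
    per_device_train_batch_size = 2    # from the recorded args
    gradient_accumulation_steps = 1    # from the recorded args
    num_devices = 8                    # assumed; in practice jax.device_count() at launch

    batch_size_per_update = (
        per_device_train_batch_size * num_devices * gradient_accumulation_steps
    )
    print(batch_size_per_update)       # 16 under the assumed device count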
wandb/run-20220810_111559-290849gb/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"train/grad_norm": 6.5625, "layer_grad_norm/": {"lm_head": {"bias": 0.031982421875, "kernel": 4.625}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 0.0556640625, "scale": 0.06103515625}, "layers": {"0": {"attention": {"k_proj": {"bias": 5.9604644775390625e-05, "kernel": 0.023681640625}, "out_proj": {"bias": 0.04150390625, "kernel": 0.2431640625}, "q_proj": {"bias": 0.002899169921875, "kernel": 0.031005859375}, "v_proj": {"bias": 0.037109375, "kernel": 0.265625}}, "feed_forward": {"intermediate_dense": {"bias": 0.04443359375, "kernel": 0.515625}, "output_dense": {"bias": 0.0159912109375, "kernel": 0.439453125}}, "final_layer_norm": {"bias": 0.146484375, "scale": 0.322265625}, "layer_norm": {"bias": 0.0703125, "scale": 0.07080078125}}, "1": {"attention": {"k_proj": {"bias": 3.4332275390625e-05, "kernel": 0.03955078125}, "out_proj": {"bias": 0.0177001953125, "kernel": 0.134765625}, "q_proj": {"bias": 0.0035247802734375, "kernel": 0.0439453125}, "v_proj": {"bias": 0.02880859375, "kernel": 0.111328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.022216796875, "kernel": 0.3359375}, "output_dense": {"bias": 0.0157470703125, "kernel": 0.259765625}}, "final_layer_norm": {"bias": 0.046142578125, "scale": 0.05712890625}, "layer_norm": {"bias": 0.05712890625, "scale": 0.039794921875}}, "10": {"attention": {"k_proj": {"bias": 3.600120544433594e-05, "kernel": 0.123046875}, "out_proj": {"bias": 0.01416015625, "kernel": 0.2001953125}, "q_proj": {"bias": 0.0078125, "kernel": 0.12255859375}, "v_proj": {"bias": 0.022705078125, "kernel": 0.2001953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.021484375, "kernel": 0.328125}, "output_dense": {"bias": 0.013671875, "kernel": 0.2734375}}, "final_layer_norm": {"bias": 0.0390625, "scale": 0.0322265625}, "layer_norm": {"bias": 0.04931640625, "scale": 0.0341796875}}, "11": {"attention": {"k_proj": {"bias": 8.344650268554688e-05, "kernel": 0.158203125}, "out_proj": {"bias": 0.0142822265625, "kernel": 0.28125}, "q_proj": {"bias": 0.0087890625, "kernel": 0.130859375}, "v_proj": {"bias": 0.024658203125, "kernel": 0.28515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.01953125, "kernel": 0.310546875}, "output_dense": {"bias": 0.013916015625, "kernel": 0.244140625}}, "final_layer_norm": {"bias": 0.03271484375, "scale": 0.0308837890625}, "layer_norm": {"bias": 0.05029296875, "scale": 0.0439453125}}, "12": {"attention": {"k_proj": {"bias": 4.982948303222656e-05, "kernel": 0.123046875}, "out_proj": {"bias": 0.0147705078125, "kernel": 0.244140625}, "q_proj": {"bias": 0.0081787109375, "kernel": 0.1162109375}, "v_proj": {"bias": 0.023681640625, "kernel": 0.2294921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.021484375, "kernel": 0.32421875}, "output_dense": {"bias": 0.014404296875, "kernel": 0.255859375}}, "final_layer_norm": {"bias": 0.0390625, "scale": 0.04248046875}, "layer_norm": {"bias": 0.046630859375, "scale": 0.0546875}}, "13": {"attention": {"k_proj": {"bias": 0.00012493133544921875, "kernel": 0.15625}, "out_proj": {"bias": 0.01519775390625, "kernel": 0.330078125}, "q_proj": {"bias": 0.0111083984375, "kernel": 0.158203125}, "v_proj": {"bias": 0.026611328125, "kernel": 0.3046875}}, "feed_forward": {"intermediate_dense": {"bias": 0.022216796875, "kernel": 0.34375}, "output_dense": {"bias": 0.01513671875, "kernel": 0.3125}}, "final_layer_norm": {"bias": 0.040283203125, "scale": 0.032958984375}, "layer_norm": {"bias": 0.051513671875, "scale": 0.091796875}}, "14": {"attention": {"k_proj": {"bias": 9.441375732421875e-05, "kernel": 
0.1005859375}, "out_proj": {"bias": 0.015625, "kernel": 0.2412109375}, "q_proj": {"bias": 0.006256103515625, "kernel": 0.099609375}, "v_proj": {"bias": 0.0235595703125, "kernel": 0.2275390625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0257568359375, "kernel": 0.39453125}, "output_dense": {"bias": 0.015380859375, "kernel": 0.33984375}}, "final_layer_norm": {"bias": 0.05126953125, "scale": 0.05517578125}, "layer_norm": {"bias": 0.041748046875, "scale": 0.03076171875}}, "15": {"attention": {"k_proj": {"bias": 0.0003070831298828125, "kernel": 0.1806640625}, "out_proj": {"bias": 0.015625, "kernel": 0.5078125}, "q_proj": {"bias": 0.0106201171875, "kernel": 0.173828125}, "v_proj": {"bias": 0.026611328125, "kernel": 0.361328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.024169921875, "kernel": 0.376953125}, "output_dense": {"bias": 0.01556396484375, "kernel": 0.349609375}}, "final_layer_norm": {"bias": 0.041015625, "scale": 0.033447265625}, "layer_norm": {"bias": 0.048095703125, "scale": 0.072265625}}, "16": {"attention": {"k_proj": {"bias": 6.389617919921875e-05, "kernel": 0.1025390625}, "out_proj": {"bias": 0.016357421875, "kernel": 0.267578125}, "q_proj": {"bias": 0.0057373046875, "kernel": 0.1005859375}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.220703125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0223388671875, "kernel": 0.359375}, "output_dense": {"bias": 0.0159912109375, "kernel": 0.341796875}}, "final_layer_norm": {"bias": 0.0380859375, "scale": 0.033935546875}, "layer_norm": {"bias": 0.043212890625, "scale": 0.034912109375}}, "17": {"attention": {"k_proj": {"bias": 4.57763671875e-05, "kernel": 0.0927734375}, "out_proj": {"bias": 0.0172119140625, "kernel": 0.23046875}, "q_proj": {"bias": 0.005889892578125, "kernel": 0.087890625}, "v_proj": {"bias": 0.0244140625, "kernel": 0.2177734375}}, "feed_forward": {"intermediate_dense": {"bias": 0.024169921875, "kernel": 0.390625}, "output_dense": {"bias": 0.01708984375, "kernel": 0.353515625}}, "final_layer_norm": {"bias": 0.041259765625, "scale": 0.036376953125}, "layer_norm": {"bias": 0.0439453125, "scale": 0.0341796875}}, "18": {"attention": {"k_proj": {"bias": 0.000247955322265625, "kernel": 0.126953125}, "out_proj": {"bias": 0.017578125, "kernel": 0.369140625}, "q_proj": {"bias": 0.0076904296875, "kernel": 0.1337890625}, "v_proj": {"bias": 0.027587890625, "kernel": 0.298828125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02587890625, "kernel": 0.44921875}, "output_dense": {"bias": 0.0174560546875, "kernel": 0.41015625}}, "final_layer_norm": {"bias": 0.04443359375, "scale": 0.03857421875}, "layer_norm": {"bias": 0.048583984375, "scale": 0.039794921875}}, "19": {"attention": {"k_proj": {"bias": 8.678436279296875e-05, "kernel": 0.140625}, "out_proj": {"bias": 0.017822265625, "kernel": 0.28125}, "q_proj": {"bias": 0.009033203125, "kernel": 0.140625}, "v_proj": {"bias": 0.0286865234375, "kernel": 0.283203125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02587890625, "kernel": 0.474609375}, "output_dense": {"bias": 0.0174560546875, "kernel": 0.421875}}, "final_layer_norm": {"bias": 0.041748046875, "scale": 0.0380859375}, "layer_norm": {"bias": 0.052734375, "scale": 0.04052734375}}, "2": {"attention": {"k_proj": {"bias": 4.982948303222656e-05, "kernel": 0.07421875}, "out_proj": {"bias": 0.0177001953125, "kernel": 0.2060546875}, "q_proj": {"bias": 0.006195068359375, "kernel": 0.06982421875}, "v_proj": {"bias": 0.03173828125, "kernel": 0.181640625}}, "feed_forward": {"intermediate_dense": {"bias": 
0.0234375, "kernel": 0.390625}, "output_dense": {"bias": 0.01556396484375, "kernel": 0.275390625}}, "final_layer_norm": {"bias": 0.047119140625, "scale": 0.03173828125}, "layer_norm": {"bias": 0.0556640625, "scale": 0.07275390625}}, "20": {"attention": {"k_proj": {"bias": 2.110004425048828e-05, "kernel": 0.095703125}, "out_proj": {"bias": 0.0185546875, "kernel": 0.142578125}, "q_proj": {"bias": 0.005157470703125, "kernel": 0.0947265625}, "v_proj": {"bias": 0.0263671875, "kernel": 0.140625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0250244140625, "kernel": 0.4765625}, "output_dense": {"bias": 0.018310546875, "kernel": 0.390625}}, "final_layer_norm": {"bias": 0.041015625, "scale": 0.04443359375}, "layer_norm": {"bias": 0.046142578125, "scale": 0.038330078125}}, "21": {"attention": {"k_proj": {"bias": 4.00543212890625e-05, "kernel": 0.1259765625}, "out_proj": {"bias": 0.0189208984375, "kernel": 0.2216796875}, "q_proj": {"bias": 0.006927490234375, "kernel": 0.12890625}, "v_proj": {"bias": 0.02734375, "kernel": 0.203125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0267333984375, "kernel": 0.51953125}, "output_dense": {"bias": 0.0185546875, "kernel": 0.41796875}}, "final_layer_norm": {"bias": 0.04541015625, "scale": 0.04736328125}, "layer_norm": {"bias": 0.044189453125, "scale": 0.054443359375}}, "22": {"attention": {"k_proj": {"bias": 3.3855438232421875e-05, "kernel": 0.1181640625}, "out_proj": {"bias": 0.019775390625, "kernel": 0.240234375}, "q_proj": {"bias": 0.006011962890625, "kernel": 0.11279296875}, "v_proj": {"bias": 0.028076171875, "kernel": 0.21875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0269775390625, "kernel": 0.515625}, "output_dense": {"bias": 0.0194091796875, "kernel": 0.400390625}}, "final_layer_norm": {"bias": 0.046142578125, "scale": 0.047119140625}, "layer_norm": {"bias": 0.049560546875, "scale": 0.0458984375}}, "23": {"attention": {"k_proj": {"bias": 0.0001087188720703125, "kernel": 0.16015625}, "out_proj": {"bias": 0.0198974609375, "kernel": 0.443359375}, "q_proj": {"bias": 0.008544921875, "kernel": 0.1630859375}, "v_proj": {"bias": 0.03173828125, "kernel": 0.35546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0263671875, "kernel": 0.53125}, "output_dense": {"bias": 0.01953125, "kernel": 0.400390625}}, "final_layer_norm": {"bias": 0.044677734375, "scale": 0.04638671875}, "layer_norm": {"bias": 0.05615234375, "scale": 0.056396484375}}, "24": {"attention": {"k_proj": {"bias": 6.246566772460938e-05, "kernel": 0.169921875}, "out_proj": {"bias": 0.0191650390625, "kernel": 0.36328125}, "q_proj": {"bias": 0.00933837890625, "kernel": 0.18359375}, "v_proj": {"bias": 0.03271484375, "kernel": 0.328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02685546875, "kernel": 0.5390625}, "output_dense": {"bias": 0.01904296875, "kernel": 0.37890625}}, "final_layer_norm": {"bias": 0.04736328125, "scale": 0.04345703125}, "layer_norm": {"bias": 0.0625, "scale": 0.041015625}}, "25": {"attention": {"k_proj": {"bias": 6.079673767089844e-05, "kernel": 0.169921875}, "out_proj": {"bias": 0.0196533203125, "kernel": 0.3125}, "q_proj": {"bias": 0.00860595703125, "kernel": 0.16015625}, "v_proj": {"bias": 0.03271484375, "kernel": 0.32421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.02734375, "kernel": 0.55859375}, "output_dense": {"bias": 0.01953125, "kernel": 0.375}}, "final_layer_norm": {"bias": 0.050537109375, "scale": 0.0478515625}, "layer_norm": {"bias": 0.06005859375, "scale": 0.06298828125}}, "26": {"attention": {"k_proj": {"bias": 
5.9604644775390625e-05, "kernel": 0.169921875}, "out_proj": {"bias": 0.01953125, "kernel": 0.29296875}, "q_proj": {"bias": 0.01025390625, "kernel": 0.177734375}, "v_proj": {"bias": 0.0341796875, "kernel": 0.296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.026611328125, "kernel": 0.51171875}, "output_dense": {"bias": 0.01904296875, "kernel": 0.353515625}}, "final_layer_norm": {"bias": 0.0478515625, "scale": 0.04443359375}, "layer_norm": {"bias": 0.060791015625, "scale": 0.06396484375}}, "27": {"attention": {"k_proj": {"bias": 0.00011396408081054688, "kernel": 0.2021484375}, "out_proj": {"bias": 0.01806640625, "kernel": 0.44921875}, "q_proj": {"bias": 0.01068115234375, "kernel": 0.2138671875}, "v_proj": {"bias": 0.03466796875, "kernel": 0.435546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.02734375, "kernel": 0.515625}, "output_dense": {"bias": 0.0181884765625, "kernel": 0.36328125}}, "final_layer_norm": {"bias": 0.05078125, "scale": 0.045654296875}, "layer_norm": {"bias": 0.06640625, "scale": 0.04931640625}}, "28": {"attention": {"k_proj": {"bias": 0.0001049041748046875, "kernel": 0.20703125}, "out_proj": {"bias": 0.0164794921875, "kernel": 0.392578125}, "q_proj": {"bias": 0.01165771484375, "kernel": 0.208984375}, "v_proj": {"bias": 0.031494140625, "kernel": 0.404296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0234375, "kernel": 0.45703125}, "output_dense": {"bias": 0.016357421875, "kernel": 0.326171875}}, "final_layer_norm": {"bias": 0.04248046875, "scale": 0.044921875}, "layer_norm": {"bias": 0.0673828125, "scale": 0.08447265625}}, "29": {"attention": {"k_proj": {"bias": 9.918212890625e-05, "kernel": 0.267578125}, "out_proj": {"bias": 0.0157470703125, "kernel": 0.28515625}, "q_proj": {"bias": 0.01495361328125, "kernel": 0.265625}, "v_proj": {"bias": 0.02978515625, "kernel": 0.31640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0234375, "kernel": 0.494140625}, "output_dense": {"bias": 0.01531982421875, "kernel": 0.296875}}, "final_layer_norm": {"bias": 0.03955078125, "scale": 0.03515625}, "layer_norm": {"bias": 0.0654296875, "scale": 0.061279296875}}, "3": {"attention": {"k_proj": {"bias": 0.00012111663818359375, "kernel": 0.0986328125}, "out_proj": {"bias": 0.016845703125, "kernel": 0.314453125}, "q_proj": {"bias": 0.00726318359375, "kernel": 0.0888671875}, "v_proj": {"bias": 0.0283203125, "kernel": 0.2470703125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0242919921875, "kernel": 0.3828125}, "output_dense": {"bias": 0.0150146484375, "kernel": 0.271484375}}, "final_layer_norm": {"bias": 0.0458984375, "scale": 0.03125}, "layer_norm": {"bias": 0.0498046875, "scale": 0.0380859375}}, "30": {"attention": {"k_proj": {"bias": 0.0001220703125, "kernel": 0.13671875}, "out_proj": {"bias": 0.0152587890625, "kernel": 0.328125}, "q_proj": {"bias": 0.006378173828125, "kernel": 0.138671875}, "v_proj": {"bias": 0.029296875, "kernel": 0.3671875}}, "feed_forward": {"intermediate_dense": {"bias": 0.023681640625, "kernel": 0.51953125}, "output_dense": {"bias": 0.01446533203125, "kernel": 0.26171875}}, "final_layer_norm": {"bias": 0.04150390625, "scale": 0.03564453125}, "layer_norm": {"bias": 0.04931640625, "scale": 0.037109375}}, "31": {"attention": {"k_proj": {"bias": 0.00010347366333007812, "kernel": 0.14453125}, "out_proj": {"bias": 0.0140380859375, "kernel": 0.29296875}, "q_proj": {"bias": 0.006378173828125, "kernel": 0.134765625}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.314453125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02392578125, 
"kernel": 0.51953125}, "output_dense": {"bias": 0.01385498046875, "kernel": 0.2578125}}, "final_layer_norm": {"bias": 0.0390625, "scale": 0.03662109375}, "layer_norm": {"bias": 0.039306640625, "scale": 0.0291748046875}}, "32": {"attention": {"k_proj": {"bias": 8.296966552734375e-05, "kernel": 0.15625}, "out_proj": {"bias": 0.01263427734375, "kernel": 0.28125}, "q_proj": {"bias": 0.0079345703125, "kernel": 0.1533203125}, "v_proj": {"bias": 0.0264892578125, "kernel": 0.4921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0216064453125, "kernel": 0.431640625}, "output_dense": {"bias": 0.01129150390625, "kernel": 0.212890625}}, "final_layer_norm": {"bias": 0.04150390625, "scale": 0.03271484375}, "layer_norm": {"bias": 0.046630859375, "scale": 0.05419921875}}, "33": {"attention": {"k_proj": {"bias": 5.9604644775390625e-05, "kernel": 0.166015625}, "out_proj": {"bias": 0.01092529296875, "kernel": 0.2275390625}, "q_proj": {"bias": 0.008544921875, "kernel": 0.166015625}, "v_proj": {"bias": 0.023193359375, "kernel": 0.34765625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0196533203125, "kernel": 0.390625}, "output_dense": {"bias": 0.00897216796875, "kernel": 0.1875}}, "final_layer_norm": {"bias": 0.04345703125, "scale": 0.0361328125}, "layer_norm": {"bias": 0.039794921875, "scale": 0.0498046875}}, "34": {"attention": {"k_proj": {"bias": 0.0002346038818359375, "kernel": 0.158203125}, "out_proj": {"bias": 0.0081787109375, "kernel": 0.181640625}, "q_proj": {"bias": 0.006927490234375, "kernel": 0.14453125}, "v_proj": {"bias": 0.0177001953125, "kernel": 0.25390625}}, "feed_forward": {"intermediate_dense": {"bias": 0.01434326171875, "kernel": 0.291015625}, "output_dense": {"bias": 0.0072021484375, "kernel": 0.1748046875}}, "final_layer_norm": {"bias": 0.028076171875, "scale": 0.025146484375}, "layer_norm": {"bias": 0.03369140625, "scale": 0.026611328125}}, "35": {"attention": {"k_proj": {"bias": 0.0001506805419921875, "kernel": 0.10791015625}, "out_proj": {"bias": 0.00640869140625, "kernel": 0.2109375}, "q_proj": {"bias": 0.004852294921875, "kernel": 0.10791015625}, "v_proj": {"bias": 0.01177978515625, "kernel": 0.21484375}}, "feed_forward": {"intermediate_dense": {"bias": 0.010498046875, "kernel": 0.2119140625}, "output_dense": {"bias": 0.005889892578125, "kernel": 0.15234375}}, "final_layer_norm": {"bias": 0.0206298828125, "scale": 0.0220947265625}, "layer_norm": {"bias": 0.024169921875, "scale": 0.02880859375}}, "36": {"attention": {"k_proj": {"bias": 4.410743713378906e-05, "kernel": 0.1005859375}, "out_proj": {"bias": 0.005645751953125, "kernel": 0.1552734375}, "q_proj": {"bias": 0.00445556640625, "kernel": 0.095703125}, "v_proj": {"bias": 0.00946044921875, "kernel": 0.14453125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0089111328125, "kernel": 0.177734375}, "output_dense": {"bias": 0.0050048828125, "kernel": 0.111328125}}, "final_layer_norm": {"bias": 0.017578125, "scale": 0.01513671875}, "layer_norm": {"bias": 0.0191650390625, "scale": 0.01806640625}}, "37": {"attention": {"k_proj": {"bias": 9.441375732421875e-05, "kernel": 0.0849609375}, "out_proj": {"bias": 0.004913330078125, "kernel": 0.11474609375}, "q_proj": {"bias": 0.00390625, "kernel": 0.0830078125}, "v_proj": {"bias": 0.00897216796875, "kernel": 0.1318359375}}, "feed_forward": {"intermediate_dense": {"bias": 0.00823974609375, "kernel": 0.16796875}, "output_dense": {"bias": 0.004241943359375, "kernel": 0.09716796875}}, "final_layer_norm": {"bias": 0.015869140625, "scale": 0.01434326171875}, "layer_norm": {"bias": 
0.019287109375, "scale": 0.015869140625}}, "38": {"attention": {"k_proj": {"bias": 5.650520324707031e-05, "kernel": 0.09130859375}, "out_proj": {"bias": 0.0040283203125, "kernel": 0.11865234375}, "q_proj": {"bias": 0.00396728515625, "kernel": 0.08642578125}, "v_proj": {"bias": 0.007354736328125, "kernel": 0.1279296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0072021484375, "kernel": 0.150390625}, "output_dense": {"bias": 0.0034637451171875, "kernel": 0.09423828125}}, "final_layer_norm": {"bias": 0.0152587890625, "scale": 0.0146484375}, "layer_norm": {"bias": 0.0162353515625, "scale": 0.0135498046875}}, "39": {"attention": {"k_proj": {"bias": 5.316734313964844e-05, "kernel": 0.09619140625}, "out_proj": {"bias": 0.0030975341796875, "kernel": 0.09619140625}, "q_proj": {"bias": 0.00408935546875, "kernel": 0.0908203125}, "v_proj": {"bias": 0.006011962890625, "kernel": 0.10986328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.005401611328125, "kernel": 0.12109375}, "output_dense": {"bias": 0.0025634765625, "kernel": 0.08642578125}}, "final_layer_norm": {"bias": 0.01202392578125, "scale": 0.01226806640625}, "layer_norm": {"bias": 0.0150146484375, "scale": 0.01556396484375}}, "4": {"attention": {"k_proj": {"bias": 0.000148773193359375, "kernel": 0.10498046875}, "out_proj": {"bias": 0.015869140625, "kernel": 0.361328125}, "q_proj": {"bias": 0.0072021484375, "kernel": 0.1005859375}, "v_proj": {"bias": 0.026123046875, "kernel": 0.3046875}}, "feed_forward": {"intermediate_dense": {"bias": 0.023193359375, "kernel": 0.36328125}, "output_dense": {"bias": 0.014404296875, "kernel": 0.29296875}}, "final_layer_norm": {"bias": 0.042724609375, "scale": 0.034423828125}, "layer_norm": {"bias": 0.0478515625, "scale": 0.060546875}}, "40": {"attention": {"k_proj": {"bias": 5.269050598144531e-05, "kernel": 0.046875}, "out_proj": {"bias": 0.0025787353515625, "kernel": 0.080078125}, "q_proj": {"bias": 0.0020294189453125, "kernel": 0.0458984375}, "v_proj": {"bias": 0.004150390625, "kernel": 0.07080078125}}, "feed_forward": {"intermediate_dense": {"bias": 0.004302978515625, "kernel": 0.09326171875}, "output_dense": {"bias": 0.0023040771484375, "kernel": 0.060791015625}}, "final_layer_norm": {"bias": 0.0087890625, "scale": 0.011474609375}, "layer_norm": {"bias": 0.00823974609375, "scale": 0.007781982421875}}, "41": {"attention": {"k_proj": {"bias": 4.0531158447265625e-05, "kernel": 0.0673828125}, "out_proj": {"bias": 0.002044677734375, "kernel": 0.087890625}, "q_proj": {"bias": 0.0025787353515625, "kernel": 0.0634765625}, "v_proj": {"bias": 0.00439453125, "kernel": 0.1044921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0035858154296875, "kernel": 0.091796875}, "output_dense": {"bias": 0.00168609619140625, "kernel": 0.064453125}}, "final_layer_norm": {"bias": 0.00921630859375, "scale": 0.0106201171875}, "layer_norm": {"bias": 0.010986328125, "scale": 0.0128173828125}}, "42": {"attention": {"k_proj": {"bias": 1.1801719665527344e-05, "kernel": 0.02099609375}, "out_proj": {"bias": 0.001678466796875, "kernel": 0.048828125}, "q_proj": {"bias": 0.0009307861328125, "kernel": 0.02197265625}, "v_proj": {"bias": 0.002349853515625, "kernel": 0.0478515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.002685546875, "kernel": 0.07421875}, "output_dense": {"bias": 0.0014801025390625, "kernel": 0.053955078125}}, "final_layer_norm": {"bias": 0.0059814453125, "scale": 0.00885009765625}, "layer_norm": {"bias": 0.0045166015625, "scale": 0.00439453125}}, "43": {"attention": {"k_proj": {"bias": 
8.046627044677734e-06, "kernel": 0.01806640625}, "out_proj": {"bias": 0.0015106201171875, "kernel": 0.035888671875}, "q_proj": {"bias": 0.00095367431640625, "kernel": 0.0208740234375}, "v_proj": {"bias": 0.00171661376953125, "kernel": 0.031005859375}}, "feed_forward": {"intermediate_dense": {"bias": 0.002777099609375, "kernel": 0.083984375}, "output_dense": {"bias": 0.00125885009765625, "kernel": 0.052734375}}, "final_layer_norm": {"bias": 0.00689697265625, "scale": 0.007110595703125}, "layer_norm": {"bias": 0.0036468505859375, "scale": 0.005706787109375}}, "44": {"attention": {"k_proj": {"bias": 1.3113021850585938e-05, "kernel": 0.023681640625}, "out_proj": {"bias": 0.001312255859375, "kernel": 0.033447265625}, "q_proj": {"bias": 0.000946044921875, "kernel": 0.021484375}, "v_proj": {"bias": 0.0017547607421875, "kernel": 0.03515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0023956298828125, "kernel": 0.0771484375}, "output_dense": {"bias": 0.001129150390625, "kernel": 0.052001953125}}, "final_layer_norm": {"bias": 0.005706787109375, "scale": 0.005859375}, "layer_norm": {"bias": 0.0042724609375, "scale": 0.004730224609375}}, "45": {"attention": {"k_proj": {"bias": 1.4424324035644531e-05, "kernel": 0.01239013671875}, "out_proj": {"bias": 0.00110626220703125, "kernel": 0.0267333984375}, "q_proj": {"bias": 0.00136566162109375, "kernel": 0.029541015625}, "v_proj": {"bias": 0.00142669677734375, "kernel": 0.027587890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0019989013671875, "kernel": 0.06396484375}, "output_dense": {"bias": 0.000881195068359375, "kernel": 0.05615234375}}, "final_layer_norm": {"bias": 0.006072998046875, "scale": 0.006591796875}, "layer_norm": {"bias": 0.004791259765625, "scale": 0.00494384765625}}, "46": {"attention": {"k_proj": {"bias": 5.745887756347656e-05, "kernel": 0.006439208984375}, "out_proj": {"bias": 0.000888824462890625, "kernel": 0.0289306640625}, "q_proj": {"bias": 0.000591278076171875, "kernel": 0.011962890625}, "v_proj": {"bias": 0.0010986328125, "kernel": 0.0233154296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.00141143798828125, "kernel": 0.03955078125}, "output_dense": {"bias": 0.000873565673828125, "kernel": 0.0478515625}}, "final_layer_norm": {"bias": 0.00433349609375, "scale": 0.00433349609375}, "layer_norm": {"bias": 0.0036468505859375, "scale": 0.003814697265625}}, "47": {"attention": {"k_proj": {"bias": 0.00011301040649414062, "kernel": 0.003997802734375}, "out_proj": {"bias": 0.000896453857421875, "kernel": 0.06640625}, "q_proj": {"bias": 0.00014591217041015625, "kernel": 0.00286865234375}, "v_proj": {"bias": 0.00118255615234375, "kernel": 0.0230712890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0010528564453125, "kernel": 0.0252685546875}, "output_dense": {"bias": 0.000881195068359375, "kernel": 0.1787109375}}, "final_layer_norm": {"bias": 0.005950927734375, "scale": 0.00677490234375}, "layer_norm": {"bias": 0.005859375, "scale": 0.005767822265625}}, "5": {"attention": {"k_proj": {"bias": 6.4849853515625e-05, "kernel": 0.123046875}, "out_proj": {"bias": 0.0159912109375, "kernel": 0.203125}, "q_proj": {"bias": 0.007598876953125, "kernel": 0.12158203125}, "v_proj": {"bias": 0.02685546875, "kernel": 0.1953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.023193359375, "kernel": 0.33984375}, "output_dense": {"bias": 0.0147705078125, "kernel": 0.275390625}}, "final_layer_norm": {"bias": 0.041015625, "scale": 0.037841796875}, "layer_norm": {"bias": 0.052734375, "scale": 0.04736328125}}, "6": 
{"attention": {"k_proj": {"bias": 7.152557373046875e-05, "kernel": 0.1318359375}, "out_proj": {"bias": 0.0152587890625, "kernel": 0.349609375}, "q_proj": {"bias": 0.00823974609375, "kernel": 0.119140625}, "v_proj": {"bias": 0.02685546875, "kernel": 0.31640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.022216796875, "kernel": 0.341796875}, "output_dense": {"bias": 0.014404296875, "kernel": 0.271484375}}, "final_layer_norm": {"bias": 0.0380859375, "scale": 0.0322265625}, "layer_norm": {"bias": 0.049560546875, "scale": 0.058349609375}}, "7": {"attention": {"k_proj": {"bias": 7.62939453125e-05, "kernel": 0.1328125}, "out_proj": {"bias": 0.0150146484375, "kernel": 0.349609375}, "q_proj": {"bias": 0.00921630859375, "kernel": 0.126953125}, "v_proj": {"bias": 0.0255126953125, "kernel": 0.30859375}}, "feed_forward": {"intermediate_dense": {"bias": 0.021484375, "kernel": 0.33984375}, "output_dense": {"bias": 0.01409912109375, "kernel": 0.26171875}}, "final_layer_norm": {"bias": 0.038818359375, "scale": 0.033935546875}, "layer_norm": {"bias": 0.05126953125, "scale": 0.050048828125}}, "8": {"attention": {"k_proj": {"bias": 7.915496826171875e-05, "kernel": 0.12060546875}, "out_proj": {"bias": 0.01507568359375, "kernel": 0.302734375}, "q_proj": {"bias": 0.007568359375, "kernel": 0.11474609375}, "v_proj": {"bias": 0.0262451171875, "kernel": 0.27734375}}, "feed_forward": {"intermediate_dense": {"bias": 0.023193359375, "kernel": 0.361328125}, "output_dense": {"bias": 0.01409912109375, "kernel": 0.275390625}}, "final_layer_norm": {"bias": 0.044677734375, "scale": 0.033447265625}, "layer_norm": {"bias": 0.0498046875, "scale": 0.06103515625}}, "9": {"attention": {"k_proj": {"bias": 0.00011777877807617188, "kernel": 0.1513671875}, "out_proj": {"bias": 0.0140380859375, "kernel": 0.416015625}, "q_proj": {"bias": 0.00830078125, "kernel": 0.1376953125}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.40234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0213623046875, "kernel": 0.3515625}, "output_dense": {"bias": 0.0135498046875, "kernel": 0.279296875}}, "final_layer_norm": {"bias": 0.037109375, "scale": 0.0322265625}, "layer_norm": {"bias": 0.04443359375, "scale": 0.04833984375}}}, "pos_conv_embed": {"conv": {"bias": 0.034912109375, "weight_g": 0.044189453125, "weight_v": 0.287109375}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "1": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "2": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "3": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "4": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "5": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "6": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}}}, "feature_projection": {"layer_norm": {"bias": 0.1337890625, "scale": 0.1611328125}, "projection": {"bias": 0.0556640625, "kernel": 1.0546875}}, "masked_spec_embed": 0.0}}, "layer_param_norm/": {"lm_head": {"bias": 0.7824921607971191, "kernel": 55.72966766357422}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 59.6768798828125, "scale": 74.17054748535156}, "layers": {"0": {"attention": {"k_proj": {"bias": 0.37033993005752563, "kernel": 27.536663055419922}, "out_proj": {"bias": 1.6469175815582275, "kernel": 26.147050857543945}, "q_proj": {"bias": 1.5330281257629395, 
"kernel": 27.813282012939453}, "v_proj": {"bias": 0.44783300161361694, "kernel": 26.55841064453125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9835489988327026, "kernel": 100.66567993164062}, "output_dense": {"bias": 1.116748571395874, "kernel": 96.76679992675781}}, "final_layer_norm": {"bias": 1.335214376449585, "scale": 19.85782241821289}, "layer_norm": {"bias": 2.923041343688965, "scale": 15.398418426513672}}, "1": {"attention": {"k_proj": {"bias": 0.3767347037792206, "kernel": 41.013240814208984}, "out_proj": {"bias": 1.3653483390808105, "kernel": 43.371070861816406}, "q_proj": {"bias": 3.0925614833831787, "kernel": 41.05661392211914}, "v_proj": {"bias": 0.2924947738647461, "kernel": 41.61189270019531}}, "feed_forward": {"intermediate_dense": {"bias": 1.9882662296295166, "kernel": 98.7442626953125}, "output_dense": {"bias": 0.8527815341949463, "kernel": 87.83541870117188}}, "final_layer_norm": {"bias": 1.3605934381484985, "scale": 19.084806442260742}, "layer_norm": {"bias": 1.9318764209747314, "scale": 17.761367797851562}}, "10": {"attention": {"k_proj": {"bias": 0.4123449921607971, "kernel": 49.44670486450195}, "out_proj": {"bias": 1.3130683898925781, "kernel": 52.27025604248047}, "q_proj": {"bias": 2.48445200920105, "kernel": 49.528873443603516}, "v_proj": {"bias": 0.3416975736618042, "kernel": 52.344085693359375}}, "feed_forward": {"intermediate_dense": {"bias": 1.974550485610962, "kernel": 102.70410919189453}, "output_dense": {"bias": 0.5955485105514526, "kernel": 95.81275939941406}}, "final_layer_norm": {"bias": 2.3762400150299072, "scale": 20.81279754638672}, "layer_norm": {"bias": 1.806241512298584, "scale": 21.429487228393555}}, "11": {"attention": {"k_proj": {"bias": 0.4460787773132324, "kernel": 49.37338638305664}, "out_proj": {"bias": 1.1512463092803955, "kernel": 51.949554443359375}, "q_proj": {"bias": 2.5446064472198486, "kernel": 49.20353698730469}, "v_proj": {"bias": 0.40995872020721436, "kernel": 52.2607536315918}}, "feed_forward": {"intermediate_dense": {"bias": 2.019528388977051, "kernel": 103.56025695800781}, "output_dense": {"bias": 0.5711302757263184, "kernel": 97.53776550292969}}, "final_layer_norm": {"bias": 2.3660812377929688, "scale": 20.919677734375}, "layer_norm": {"bias": 1.7802445888519287, "scale": 22.01519203186035}}, "12": {"attention": {"k_proj": {"bias": 0.4312320351600647, "kernel": 50.07032775878906}, "out_proj": {"bias": 1.126122236251831, "kernel": 51.988826751708984}, "q_proj": {"bias": 2.4090728759765625, "kernel": 49.92729949951172}, "v_proj": {"bias": 0.4024103879928589, "kernel": 52.296756744384766}}, "feed_forward": {"intermediate_dense": {"bias": 2.0548508167266846, "kernel": 104.54823303222656}, "output_dense": {"bias": 0.5548778772354126, "kernel": 99.2693099975586}}, "final_layer_norm": {"bias": 2.2933573722839355, "scale": 20.85626983642578}, "layer_norm": {"bias": 1.8587299585342407, "scale": 22.473487854003906}}, "13": {"attention": {"k_proj": {"bias": 0.4430793821811676, "kernel": 51.76431655883789}, "out_proj": {"bias": 1.1271920204162598, "kernel": 51.86852264404297}, "q_proj": {"bias": 2.359200954437256, "kernel": 51.75225830078125}, "v_proj": {"bias": 0.3906242251396179, "kernel": 51.92781066894531}}, "feed_forward": {"intermediate_dense": {"bias": 2.0941619873046875, "kernel": 105.30806732177734}, "output_dense": {"bias": 0.5719542503356934, "kernel": 99.8712158203125}}, "final_layer_norm": {"bias": 2.2314066886901855, "scale": 21.027400970458984}, "layer_norm": {"bias": 1.9997800588607788, "scale": 22.84510040283203}}, 
"14": {"attention": {"k_proj": {"bias": 0.43604815006256104, "kernel": 51.92181396484375}, "out_proj": {"bias": 1.2681217193603516, "kernel": 49.762760162353516}, "q_proj": {"bias": 2.4942922592163086, "kernel": 52.049808502197266}, "v_proj": {"bias": 0.36820662021636963, "kernel": 49.26283264160156}}, "feed_forward": {"intermediate_dense": {"bias": 2.1335830688476562, "kernel": 105.94457244873047}, "output_dense": {"bias": 0.6063626408576965, "kernel": 101.24859619140625}}, "final_layer_norm": {"bias": 2.2754664421081543, "scale": 21.145992279052734}, "layer_norm": {"bias": 2.1295526027679443, "scale": 22.584672927856445}}, "15": {"attention": {"k_proj": {"bias": 0.45931702852249146, "kernel": 51.93058776855469}, "out_proj": {"bias": 1.3753983974456787, "kernel": 50.94449234008789}, "q_proj": {"bias": 2.5871663093566895, "kernel": 52.099769592285156}, "v_proj": {"bias": 0.459650456905365, "kernel": 50.5831298828125}}, "feed_forward": {"intermediate_dense": {"bias": 2.132938861846924, "kernel": 105.57443237304688}, "output_dense": {"bias": 0.7701732516288757, "kernel": 101.94094848632812}}, "final_layer_norm": {"bias": 2.327320098876953, "scale": 21.192947387695312}, "layer_norm": {"bias": 2.4148712158203125, "scale": 23.526634216308594}}, "16": {"attention": {"k_proj": {"bias": 0.4008745551109314, "kernel": 51.772621154785156}, "out_proj": {"bias": 1.27531099319458, "kernel": 50.134521484375}, "q_proj": {"bias": 2.667466163635254, "kernel": 51.75814437866211}, "v_proj": {"bias": 0.3768249750137329, "kernel": 49.78093719482422}}, "feed_forward": {"intermediate_dense": {"bias": 2.1094985008239746, "kernel": 106.10227966308594}, "output_dense": {"bias": 0.7860437631607056, "kernel": 102.66590881347656}}, "final_layer_norm": {"bias": 2.337951421737671, "scale": 21.583194732666016}, "layer_norm": {"bias": 2.283249855041504, "scale": 22.168060302734375}}, "17": {"attention": {"k_proj": {"bias": 0.3966267704963684, "kernel": 51.728885650634766}, "out_proj": {"bias": 1.2151354551315308, "kernel": 49.4556884765625}, "q_proj": {"bias": 2.714320182800293, "kernel": 51.81880187988281}, "v_proj": {"bias": 0.42661017179489136, "kernel": 49.11927032470703}}, "feed_forward": {"intermediate_dense": {"bias": 2.1101765632629395, "kernel": 107.14872741699219}, "output_dense": {"bias": 0.8218655586242676, "kernel": 103.06423950195312}}, "final_layer_norm": {"bias": 2.383938789367676, "scale": 22.070323944091797}, "layer_norm": {"bias": 2.222898483276367, "scale": 21.219982147216797}}, "18": {"attention": {"k_proj": {"bias": 0.4409676194190979, "kernel": 52.41611099243164}, "out_proj": {"bias": 1.3447906970977783, "kernel": 50.491905212402344}, "q_proj": {"bias": 2.614685535430908, "kernel": 52.796600341796875}, "v_proj": {"bias": 0.4518332779407501, "kernel": 50.001895904541016}}, "feed_forward": {"intermediate_dense": {"bias": 2.144195556640625, "kernel": 107.41338348388672}, "output_dense": {"bias": 0.9481453895568848, "kernel": 104.72514343261719}}, "final_layer_norm": {"bias": 2.5390124320983887, "scale": 22.15178680419922}, "layer_norm": {"bias": 2.424910068511963, "scale": 23.585906982421875}}, "19": {"attention": {"k_proj": {"bias": 0.38193291425704956, "kernel": 51.511146545410156}, "out_proj": {"bias": 1.3303101062774658, "kernel": 50.10035705566406}, "q_proj": {"bias": 2.930327892303467, "kernel": 51.865638732910156}, "v_proj": {"bias": 0.4086824655532837, "kernel": 49.38078308105469}}, "feed_forward": {"intermediate_dense": {"bias": 2.1912901401519775, "kernel": 107.95254516601562}, "output_dense": 
{"bias": 1.0248571634292603, "kernel": 105.65098571777344}}, "final_layer_norm": {"bias": 2.4923481941223145, "scale": 22.505674362182617}, "layer_norm": {"bias": 2.2888314723968506, "scale": 22.31826400756836}}, "2": {"attention": {"k_proj": {"bias": 0.454792320728302, "kernel": 47.77275085449219}, "out_proj": {"bias": 1.256988525390625, "kernel": 45.969764709472656}, "q_proj": {"bias": 3.2510807514190674, "kernel": 47.61664581298828}, "v_proj": {"bias": 0.339598685503006, "kernel": 45.72273254394531}}, "feed_forward": {"intermediate_dense": {"bias": 1.9737317562103271, "kernel": 103.32754516601562}, "output_dense": {"bias": 0.7398276329040527, "kernel": 91.11263275146484}}, "final_layer_norm": {"bias": 1.5421981811523438, "scale": 21.561111450195312}, "layer_norm": {"bias": 1.7081801891326904, "scale": 20.852447509765625}}, "20": {"attention": {"k_proj": {"bias": 0.4067543148994446, "kernel": 51.605438232421875}, "out_proj": {"bias": 1.359946370124817, "kernel": 49.45553207397461}, "q_proj": {"bias": 2.8498687744140625, "kernel": 52.224571228027344}, "v_proj": {"bias": 0.36227869987487793, "kernel": 48.43864822387695}}, "feed_forward": {"intermediate_dense": {"bias": 2.1725549697875977, "kernel": 109.17405700683594}, "output_dense": {"bias": 1.1388803720474243, "kernel": 106.40528106689453}}, "final_layer_norm": {"bias": 2.435314655303955, "scale": 23.4317626953125}, "layer_norm": {"bias": 2.231672525405884, "scale": 22.230525970458984}}, "21": {"attention": {"k_proj": {"bias": 0.4161534905433655, "kernel": 51.942527770996094}, "out_proj": {"bias": 1.403618335723877, "kernel": 49.51059341430664}, "q_proj": {"bias": 2.7690629959106445, "kernel": 52.67078399658203}, "v_proj": {"bias": 0.41060006618499756, "kernel": 48.64883041381836}}, "feed_forward": {"intermediate_dense": {"bias": 2.2174296379089355, "kernel": 109.5155029296875}, "output_dense": {"bias": 1.253208041191101, "kernel": 106.88243865966797}}, "final_layer_norm": {"bias": 2.4632763862609863, "scale": 23.175764083862305}, "layer_norm": {"bias": 2.2785892486572266, "scale": 22.234222412109375}}, "22": {"attention": {"k_proj": {"bias": 0.45357397198677063, "kernel": 52.54576110839844}, "out_proj": {"bias": 1.349219560623169, "kernel": 49.533172607421875}, "q_proj": {"bias": 2.8105549812316895, "kernel": 52.86981201171875}, "v_proj": {"bias": 0.3973655700683594, "kernel": 49.33363342285156}}, "feed_forward": {"intermediate_dense": {"bias": 2.1619315147399902, "kernel": 109.95498657226562}, "output_dense": {"bias": 1.3076066970825195, "kernel": 106.3852310180664}}, "final_layer_norm": {"bias": 2.3642821311950684, "scale": 22.684059143066406}, "layer_norm": {"bias": 2.3316237926483154, "scale": 21.545879364013672}}, "23": {"attention": {"k_proj": {"bias": 0.4928613007068634, "kernel": 53.47669219970703}, "out_proj": {"bias": 1.564335823059082, "kernel": 50.98707580566406}, "q_proj": {"bias": 2.7065773010253906, "kernel": 53.582611083984375}, "v_proj": {"bias": 0.5810648202896118, "kernel": 51.54853820800781}}, "feed_forward": {"intermediate_dense": {"bias": 2.131969690322876, "kernel": 109.86410522460938}, "output_dense": {"bias": 1.2769315242767334, "kernel": 107.37890625}}, "final_layer_norm": {"bias": 2.767916679382324, "scale": 22.887813568115234}, "layer_norm": {"bias": 2.824352264404297, "scale": 23.373172760009766}}, "24": {"attention": {"k_proj": {"bias": 0.46056002378463745, "kernel": 52.424072265625}, "out_proj": {"bias": 1.6070430278778076, "kernel": 52.50334167480469}, "q_proj": {"bias": 2.828113079071045, "kernel": 
52.40515899658203}, "v_proj": {"bias": 0.5424190163612366, "kernel": 52.51116180419922}}, "feed_forward": {"intermediate_dense": {"bias": 2.2367913722991943, "kernel": 109.35035705566406}, "output_dense": {"bias": 1.3016372919082642, "kernel": 110.30095672607422}}, "final_layer_norm": {"bias": 2.83841872215271, "scale": 22.964658737182617}, "layer_norm": {"bias": 2.56215763092041, "scale": 22.983924865722656}}, "25": {"attention": {"k_proj": {"bias": 0.42509031295776367, "kernel": 52.730464935302734}, "out_proj": {"bias": 1.363797664642334, "kernel": 50.5806884765625}, "q_proj": {"bias": 2.9342763423919678, "kernel": 52.548744201660156}, "v_proj": {"bias": 0.6404213309288025, "kernel": 51.0885009765625}}, "feed_forward": {"intermediate_dense": {"bias": 2.1367578506469727, "kernel": 109.70021057128906}, "output_dense": {"bias": 1.1017413139343262, "kernel": 110.27072143554688}}, "final_layer_norm": {"bias": 2.5763301849365234, "scale": 23.494670867919922}, "layer_norm": {"bias": 2.683134078979492, "scale": 21.88357925415039}}, "26": {"attention": {"k_proj": {"bias": 0.4836847186088562, "kernel": 53.01764678955078}, "out_proj": {"bias": 1.2433912754058838, "kernel": 51.37077331542969}, "q_proj": {"bias": 2.943906784057617, "kernel": 52.80891036987305}, "v_proj": {"bias": 0.5064959526062012, "kernel": 52.004638671875}}, "feed_forward": {"intermediate_dense": {"bias": 2.2763516902923584, "kernel": 109.44652557373047}, "output_dense": {"bias": 1.0912110805511475, "kernel": 107.40899658203125}}, "final_layer_norm": {"bias": 2.1937994956970215, "scale": 22.433353424072266}, "layer_norm": {"bias": 2.497119903564453, "scale": 22.19057273864746}}, "27": {"attention": {"k_proj": {"bias": 0.5808594226837158, "kernel": 53.76898956298828}, "out_proj": {"bias": 1.5447406768798828, "kernel": 52.95805358886719}, "q_proj": {"bias": 2.703345775604248, "kernel": 53.69578552246094}, "v_proj": {"bias": 0.6748642325401306, "kernel": 53.388118743896484}}, "feed_forward": {"intermediate_dense": {"bias": 2.404933452606201, "kernel": 107.8713150024414}, "output_dense": {"bias": 0.9485896825790405, "kernel": 107.17198181152344}}, "final_layer_norm": {"bias": 2.5252954959869385, "scale": 21.88959503173828}, "layer_norm": {"bias": 2.6147172451019287, "scale": 23.32440948486328}}, "28": {"attention": {"k_proj": {"bias": 0.5901432037353516, "kernel": 54.482521057128906}, "out_proj": {"bias": 1.5367379188537598, "kernel": 53.31493377685547}, "q_proj": {"bias": 2.9472482204437256, "kernel": 54.1741943359375}, "v_proj": {"bias": 0.5131911039352417, "kernel": 53.759761810302734}}, "feed_forward": {"intermediate_dense": {"bias": 2.3475265502929688, "kernel": 107.87416076660156}, "output_dense": {"bias": 0.8224154710769653, "kernel": 109.1680908203125}}, "final_layer_norm": {"bias": 2.425306797027588, "scale": 22.337677001953125}, "layer_norm": {"bias": 2.0914058685302734, "scale": 23.993711471557617}}, "29": {"attention": {"k_proj": {"bias": 0.46781182289123535, "kernel": 51.12034606933594}, "out_proj": {"bias": 1.5021522045135498, "kernel": 55.685630798339844}, "q_proj": {"bias": 2.809702157974243, "kernel": 51.00274658203125}, "v_proj": {"bias": 0.4760415554046631, "kernel": 55.703304290771484}}, "feed_forward": {"intermediate_dense": {"bias": 2.297222137451172, "kernel": 108.01033020019531}, "output_dense": {"bias": 0.9597339630126953, "kernel": 113.12825012207031}}, "final_layer_norm": {"bias": 2.5980498790740967, "scale": 23.459980010986328}, "layer_norm": {"bias": 2.245180130004883, "scale": 25.39927864074707}}, "3": 
{"attention": {"k_proj": {"bias": 0.45006245374679565, "kernel": 52.03215789794922}, "out_proj": {"bias": 1.4254932403564453, "kernel": 48.60858917236328}, "q_proj": {"bias": 2.8560738563537598, "kernel": 52.312644958496094}, "v_proj": {"bias": 0.3246268630027771, "kernel": 48.768699645996094}}, "feed_forward": {"intermediate_dense": {"bias": 1.9663825035095215, "kernel": 104.83622741699219}, "output_dense": {"bias": 0.6984099745750427, "kernel": 94.07957458496094}}, "final_layer_norm": {"bias": 1.8095453977584839, "scale": 21.664737701416016}, "layer_norm": {"bias": 1.9017157554626465, "scale": 22.739452362060547}}, "30": {"attention": {"k_proj": {"bias": 0.5024805665016174, "kernel": 52.825706481933594}, "out_proj": {"bias": 1.3023658990859985, "kernel": 52.053871154785156}, "q_proj": {"bias": 2.907101631164551, "kernel": 52.91836166381836}, "v_proj": {"bias": 0.49308842420578003, "kernel": 52.49382019042969}}, "feed_forward": {"intermediate_dense": {"bias": 2.2399911880493164, "kernel": 108.17861938476562}, "output_dense": {"bias": 0.9140658378601074, "kernel": 112.09104919433594}}, "final_layer_norm": {"bias": 2.4926414489746094, "scale": 24.492368698120117}, "layer_norm": {"bias": 2.316732168197632, "scale": 24.931156158447266}}, "31": {"attention": {"k_proj": {"bias": 0.5412741899490356, "kernel": 51.240806579589844}, "out_proj": {"bias": 1.2333163022994995, "kernel": 52.19988250732422}, "q_proj": {"bias": 2.6581294536590576, "kernel": 51.346168518066406}, "v_proj": {"bias": 0.5469827651977539, "kernel": 52.432586669921875}}, "feed_forward": {"intermediate_dense": {"bias": 2.3097352981567383, "kernel": 106.72758483886719}, "output_dense": {"bias": 1.0891624689102173, "kernel": 109.24717712402344}}, "final_layer_norm": {"bias": 2.2962756156921387, "scale": 24.31252670288086}, "layer_norm": {"bias": 2.3430848121643066, "scale": 24.590187072753906}}, "32": {"attention": {"k_proj": {"bias": 0.4704548716545105, "kernel": 50.39933776855469}, "out_proj": {"bias": 1.2453913688659668, "kernel": 51.57465362548828}, "q_proj": {"bias": 2.8450098037719727, "kernel": 50.34847640991211}, "v_proj": {"bias": 0.419519305229187, "kernel": 51.972320556640625}}, "feed_forward": {"intermediate_dense": {"bias": 2.2569355964660645, "kernel": 105.33018493652344}, "output_dense": {"bias": 1.146787166595459, "kernel": 108.35574340820312}}, "final_layer_norm": {"bias": 2.31538724899292, "scale": 24.518985748291016}, "layer_norm": {"bias": 2.417579174041748, "scale": 24.991830825805664}}, "33": {"attention": {"k_proj": {"bias": 0.48390907049179077, "kernel": 50.28266143798828}, "out_proj": {"bias": 1.280959129333496, "kernel": 51.30557632446289}, "q_proj": {"bias": 2.998173713684082, "kernel": 50.253868103027344}, "v_proj": {"bias": 0.4416005611419678, "kernel": 51.71035385131836}}, "feed_forward": {"intermediate_dense": {"bias": 2.279946804046631, "kernel": 103.67684173583984}, "output_dense": {"bias": 1.1754591464996338, "kernel": 106.81501007080078}}, "final_layer_norm": {"bias": 2.257889747619629, "scale": 24.207740783691406}, "layer_norm": {"bias": 2.5865607261657715, "scale": 25.067882537841797}}, "34": {"attention": {"k_proj": {"bias": 0.45390456914901733, "kernel": 49.26523208618164}, "out_proj": {"bias": 1.5265402793884277, "kernel": 52.46200942993164}, "q_proj": {"bias": 2.913527011871338, "kernel": 49.268951416015625}, "v_proj": {"bias": 0.4023621678352356, "kernel": 52.534793853759766}}, "feed_forward": {"intermediate_dense": {"bias": 2.3745017051696777, "kernel": 102.21138000488281}, "output_dense": 
{"bias": 1.1227428913116455, "kernel": 105.72161102294922}}, "final_layer_norm": {"bias": 2.20273494720459, "scale": 23.640857696533203}, "layer_norm": {"bias": 2.615731716156006, "scale": 25.498104095458984}}, "35": {"attention": {"k_proj": {"bias": 0.5336894989013672, "kernel": 51.04902648925781}, "out_proj": {"bias": 1.4900906085968018, "kernel": 51.15006637573242}, "q_proj": {"bias": 2.573650360107422, "kernel": 51.323951721191406}, "v_proj": {"bias": 0.4889468550682068, "kernel": 51.176795959472656}}, "feed_forward": {"intermediate_dense": {"bias": 2.502547264099121, "kernel": 100.7535400390625}, "output_dense": {"bias": 1.0254027843475342, "kernel": 104.25007629394531}}, "final_layer_norm": {"bias": 2.295243263244629, "scale": 23.61981964111328}, "layer_norm": {"bias": 2.499863862991333, "scale": 26.093618392944336}}, "36": {"attention": {"k_proj": {"bias": 0.4473738670349121, "kernel": 48.319740295410156}, "out_proj": {"bias": 1.5177178382873535, "kernel": 52.26478958129883}, "q_proj": {"bias": 2.6124308109283447, "kernel": 48.236427307128906}, "v_proj": {"bias": 0.39316776394844055, "kernel": 52.66499328613281}}, "feed_forward": {"intermediate_dense": {"bias": 2.3669564723968506, "kernel": 99.60517120361328}, "output_dense": {"bias": 1.0260361433029175, "kernel": 103.68067932128906}}, "final_layer_norm": {"bias": 2.044203519821167, "scale": 24.13540267944336}, "layer_norm": {"bias": 2.2894434928894043, "scale": 25.63806915283203}}, "37": {"attention": {"k_proj": {"bias": 0.6240901350975037, "kernel": 47.300960540771484}, "out_proj": {"bias": 1.7604304552078247, "kernel": 52.189971923828125}, "q_proj": {"bias": 2.3819239139556885, "kernel": 47.316253662109375}, "v_proj": {"bias": 0.38518577814102173, "kernel": 52.32656478881836}}, "feed_forward": {"intermediate_dense": {"bias": 2.2694783210754395, "kernel": 98.56558227539062}, "output_dense": {"bias": 1.0087021589279175, "kernel": 103.13001251220703}}, "final_layer_norm": {"bias": 1.7903382778167725, "scale": 24.500085830688477}, "layer_norm": {"bias": 2.238807439804077, "scale": 25.563133239746094}}, "38": {"attention": {"k_proj": {"bias": 0.722891092300415, "kernel": 45.45281219482422}, "out_proj": {"bias": 1.4463956356048584, "kernel": 51.495086669921875}, "q_proj": {"bias": 2.2622861862182617, "kernel": 45.454437255859375}, "v_proj": {"bias": 0.42936205863952637, "kernel": 51.565425872802734}}, "feed_forward": {"intermediate_dense": {"bias": 2.2012171745300293, "kernel": 96.4342041015625}, "output_dense": {"bias": 0.9817801713943481, "kernel": 101.30387115478516}}, "final_layer_norm": {"bias": 1.7825044393539429, "scale": 25.214679718017578}, "layer_norm": {"bias": 2.4084482192993164, "scale": 26.425636291503906}}, "39": {"attention": {"k_proj": {"bias": 0.7209377884864807, "kernel": 45.24742889404297}, "out_proj": {"bias": 1.7097458839416504, "kernel": 51.329002380371094}, "q_proj": {"bias": 2.1152358055114746, "kernel": 45.53700637817383}, "v_proj": {"bias": 0.4246286153793335, "kernel": 51.27728271484375}}, "feed_forward": {"intermediate_dense": {"bias": 2.177624225616455, "kernel": 94.23902893066406}, "output_dense": {"bias": 1.0458879470825195, "kernel": 101.17718505859375}}, "final_layer_norm": {"bias": 1.753927230834961, "scale": 25.785865783691406}, "layer_norm": {"bias": 2.33198881149292, "scale": 26.942312240600586}}, "4": {"attention": {"k_proj": {"bias": 0.44548165798187256, "kernel": 54.6234130859375}, "out_proj": {"bias": 1.652343988418579, "kernel": 50.16497039794922}, "q_proj": {"bias": 2.615248918533325, 
"kernel": 54.93098068237305}, "v_proj": {"bias": 0.34892427921295166, "kernel": 50.320098876953125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9531786441802979, "kernel": 104.50968933105469}, "output_dense": {"bias": 0.8575068712234497, "kernel": 95.54541015625}}, "final_layer_norm": {"bias": 1.9920002222061157, "scale": 21.200180053710938}, "layer_norm": {"bias": 2.054612159729004, "scale": 23.620338439941406}}, "40": {"attention": {"k_proj": {"bias": 0.6590453386306763, "kernel": 44.198089599609375}, "out_proj": {"bias": 1.6252505779266357, "kernel": 49.55699920654297}, "q_proj": {"bias": 1.9674756526947021, "kernel": 44.89208221435547}, "v_proj": {"bias": 0.4587768614292145, "kernel": 49.23614501953125}}, "feed_forward": {"intermediate_dense": {"bias": 2.0333969593048096, "kernel": 92.16896057128906}, "output_dense": {"bias": 1.087776780128479, "kernel": 98.40738677978516}}, "final_layer_norm": {"bias": 1.7852704524993896, "scale": 25.04292106628418}, "layer_norm": {"bias": 2.2756104469299316, "scale": 26.40799903869629}}, "41": {"attention": {"k_proj": {"bias": 1.7133712768554688, "kernel": 41.96858596801758}, "out_proj": {"bias": 1.3790823221206665, "kernel": 51.25593566894531}, "q_proj": {"bias": 1.71382737159729, "kernel": 42.56317901611328}, "v_proj": {"bias": 0.4695759415626526, "kernel": 50.369529724121094}}, "feed_forward": {"intermediate_dense": {"bias": 2.110393524169922, "kernel": 88.92567443847656}, "output_dense": {"bias": 1.1446669101715088, "kernel": 97.37409973144531}}, "final_layer_norm": {"bias": 2.23917293548584, "scale": 28.507984161376953}, "layer_norm": {"bias": 2.22525691986084, "scale": 28.246891021728516}}, "42": {"attention": {"k_proj": {"bias": 0.8601109981536865, "kernel": 38.31235885620117}, "out_proj": {"bias": 1.4427157640457153, "kernel": 45.07648849487305}, "q_proj": {"bias": 1.549715280532837, "kernel": 39.524009704589844}, "v_proj": {"bias": 0.6933339834213257, "kernel": 43.49076461791992}}, "feed_forward": {"intermediate_dense": {"bias": 1.9107489585876465, "kernel": 88.009765625}, "output_dense": {"bias": 1.1978566646575928, "kernel": 95.7593994140625}}, "final_layer_norm": {"bias": 1.9227323532104492, "scale": 29.817535400390625}, "layer_norm": {"bias": 1.6761282682418823, "scale": 26.810440063476562}}, "43": {"attention": {"k_proj": {"bias": 1.247081995010376, "kernel": 34.694725036621094}, "out_proj": {"bias": 1.4174811840057373, "kernel": 41.36320495605469}, "q_proj": {"bias": 1.3773530721664429, "kernel": 35.38981628417969}, "v_proj": {"bias": 0.5787136554718018, "kernel": 39.29212951660156}}, "feed_forward": {"intermediate_dense": {"bias": 1.8936835527420044, "kernel": 87.073974609375}, "output_dense": {"bias": 0.9419379234313965, "kernel": 93.74283599853516}}, "final_layer_norm": {"bias": 1.99924898147583, "scale": 32.0491943359375}, "layer_norm": {"bias": 1.7940990924835205, "scale": 25.131242752075195}}, "44": {"attention": {"k_proj": {"bias": 2.5188145637512207, "kernel": 35.16314697265625}, "out_proj": {"bias": 1.1667875051498413, "kernel": 45.019126892089844}, "q_proj": {"bias": 1.317202091217041, "kernel": 35.58300018310547}, "v_proj": {"bias": 0.38874924182891846, "kernel": 44.14718246459961}}, "feed_forward": {"intermediate_dense": {"bias": 1.9462898969650269, "kernel": 86.07953643798828}, "output_dense": {"bias": 0.859266996383667, "kernel": 91.58765411376953}}, "final_layer_norm": {"bias": 2.0454273223876953, "scale": 34.2881965637207}, "layer_norm": {"bias": 1.6815991401672363, "scale": 25.14142608642578}}, "45": 
{"attention": {"k_proj": {"bias": 2.081407308578491, "kernel": 34.86139678955078}, "out_proj": {"bias": 1.0356104373931885, "kernel": 48.59937286376953}, "q_proj": {"bias": 1.402512788772583, "kernel": 35.03264617919922}, "v_proj": {"bias": 0.4231463074684143, "kernel": 48.76853942871094}}, "feed_forward": {"intermediate_dense": {"bias": 2.016927719116211, "kernel": 82.93773651123047}, "output_dense": {"bias": 0.9764893054962158, "kernel": 87.24796295166016}}, "final_layer_norm": {"bias": 1.9180456399917603, "scale": 33.143672943115234}, "layer_norm": {"bias": 1.5726068019866943, "scale": 23.782546997070312}}, "46": {"attention": {"k_proj": {"bias": 1.5659263134002686, "kernel": 35.878021240234375}, "out_proj": {"bias": 0.8182340264320374, "kernel": 51.16078186035156}, "q_proj": {"bias": 1.5642974376678467, "kernel": 36.18907165527344}, "v_proj": {"bias": 0.4092414081096649, "kernel": 51.89159393310547}}, "feed_forward": {"intermediate_dense": {"bias": 2.0093321800231934, "kernel": 77.47581481933594}, "output_dense": {"bias": 1.1406863927841187, "kernel": 77.73695373535156}}, "final_layer_norm": {"bias": 1.8108854293823242, "scale": 28.70657730102539}, "layer_norm": {"bias": 1.3991491794586182, "scale": 22.808137893676758}}, "47": {"attention": {"k_proj": {"bias": 0.6173280477523804, "kernel": 38.678985595703125}, "out_proj": {"bias": 0.6758822202682495, "kernel": 46.45281219482422}, "q_proj": {"bias": 1.7084776163101196, "kernel": 39.426841735839844}, "v_proj": {"bias": 0.4932914674282074, "kernel": 47.617279052734375}}, "feed_forward": {"intermediate_dense": {"bias": 1.986911654472351, "kernel": 75.47482299804688}, "output_dense": {"bias": 0.6346586346626282, "kernel": 72.82707214355469}}, "final_layer_norm": {"bias": 1.1888140439987183, "scale": 23.650447845458984}, "layer_norm": {"bias": 1.2521969079971313, "scale": 20.66573715209961}}, "5": {"attention": {"k_proj": {"bias": 0.42588678002357483, "kernel": 50.1945686340332}, "out_proj": {"bias": 1.6038882732391357, "kernel": 51.2144889831543}, "q_proj": {"bias": 2.7522244453430176, "kernel": 50.37500762939453}, "v_proj": {"bias": 0.3343381881713867, "kernel": 51.71652603149414}}, "feed_forward": {"intermediate_dense": {"bias": 1.8887722492218018, "kernel": 104.60663604736328}, "output_dense": {"bias": 0.8976269960403442, "kernel": 94.77360534667969}}, "final_layer_norm": {"bias": 2.1965675354003906, "scale": 21.37998390197754}, "layer_norm": {"bias": 2.0435237884521484, "scale": 22.437192916870117}}, "6": {"attention": {"k_proj": {"bias": 0.4843112528324127, "kernel": 51.87700653076172}, "out_proj": {"bias": 1.5925445556640625, "kernel": 50.83113479614258}, "q_proj": {"bias": 2.7889723777770996, "kernel": 52.3514404296875}, "v_proj": {"bias": 0.3247200846672058, "kernel": 51.107398986816406}}, "feed_forward": {"intermediate_dense": {"bias": 1.8638136386871338, "kernel": 103.7142333984375}, "output_dense": {"bias": 0.752193808555603, "kernel": 94.57742309570312}}, "final_layer_norm": {"bias": 2.5145251750946045, "scale": 20.836563110351562}, "layer_norm": {"bias": 2.0285890102386475, "scale": 23.156789779663086}}, "7": {"attention": {"k_proj": {"bias": 0.5048109889030457, "kernel": 51.46453094482422}, "out_proj": {"bias": 1.4398455619812012, "kernel": 51.139068603515625}, "q_proj": {"bias": 2.550907611846924, "kernel": 51.92047119140625}, "v_proj": {"bias": 0.42719271779060364, "kernel": 50.953285217285156}}, "feed_forward": {"intermediate_dense": {"bias": 1.8739449977874756, "kernel": 103.49991607666016}, "output_dense": {"bias": 
0.5876641273498535, "kernel": 94.39387512207031}}, "final_layer_norm": {"bias": 2.416801929473877, "scale": 21.010677337646484}, "layer_norm": {"bias": 1.9788501262664795, "scale": 22.19708824157715}}, "8": {"attention": {"k_proj": {"bias": 0.49711495637893677, "kernel": 51.12122344970703}, "out_proj": {"bias": 1.2548246383666992, "kernel": 51.65118408203125}, "q_proj": {"bias": 2.541980504989624, "kernel": 51.03407287597656}, "v_proj": {"bias": 0.35420340299606323, "kernel": 51.662872314453125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9298467636108398, "kernel": 103.21916198730469}, "output_dense": {"bias": 0.5482766628265381, "kernel": 93.97574615478516}}, "final_layer_norm": {"bias": 2.3526053428649902, "scale": 20.7393856048584}, "layer_norm": {"bias": 1.9221248626708984, "scale": 22.40435028076172}}, "9": {"attention": {"k_proj": {"bias": 0.5231171250343323, "kernel": 52.01068878173828}, "out_proj": {"bias": 1.4968843460083008, "kernel": 52.671897888183594}, "q_proj": {"bias": 2.4629459381103516, "kernel": 52.26807403564453}, "v_proj": {"bias": 0.38445231318473816, "kernel": 52.86597442626953}}, "feed_forward": {"intermediate_dense": {"bias": 2.026733875274658, "kernel": 101.98421478271484}, "output_dense": {"bias": 0.6828575134277344, "kernel": 94.36962890625}}, "final_layer_norm": {"bias": 2.325080156326294, "scale": 20.160720825195312}, "layer_norm": {"bias": 2.0236480236053467, "scale": 24.083864212036133}}}, "pos_conv_embed": {"conv": {"bias": 5.847014427185059, "weight_g": 9.12463665008545, "weight_v": 93.52015686035156}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 2.0290679931640625, "kernel": 20.55536460876465}, "layer_norm": {"bias": 4.550922393798828, "scale": 16.167570114135742}}, "1": {"conv": {"bias": 1.7790228128433228, "kernel": 51.24136734008789}, "layer_norm": {"bias": 5.962646961212158, "scale": 23.268157958984375}}, "2": {"conv": {"bias": 1.140576720237732, "kernel": 46.50312042236328}, "layer_norm": {"bias": 4.176670551300049, "scale": 20.370853424072266}}, "3": {"conv": {"bias": 0.6725863218307495, "kernel": 44.397525787353516}, "layer_norm": {"bias": 3.888174533843994, "scale": 17.53795051574707}}, "4": {"conv": {"bias": 0.6373162269592285, "kernel": 41.314056396484375}, "layer_norm": {"bias": 2.385471820831299, "scale": 16.34571647644043}}, "5": {"conv": {"bias": 0.5147221684455872, "kernel": 37.479759216308594}, "layer_norm": {"bias": 2.020900011062622, "scale": 17.064470291137695}}, "6": {"conv": {"bias": 0.4947893023490906, "kernel": 40.64780044555664}, "layer_norm": {"bias": 0.5876954793930054, "scale": 19.058603286743164}}}}, "feature_projection": {"layer_norm": {"bias": 6.376383304595947, "scale": 16.443069458007812}, "projection": {"bias": 1.8670344352722168, "kernel": 37.218414306640625}}, "masked_spec_embed": 11.914372444152832}}, "train/learning_rate": 1.9151924789184704e-05, "train/loss": 0.204779714345932, "train/param_norm": 1241.662353515625, "_runtime": 5032, "_timestamp": 1660135191, "_step": 275600, "_wandb": {"runtime": 5033}}
wandb/run-20220810_111559-290849gb/logs/debug-internal.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6262ea728ba649f48f692f8eb4a7d194e2bcdfe33cede302167d1a4d75ecc09
3
+ size 160449
wandb/run-20220810_111559-290849gb/logs/debug.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1e44b22c4702845ce31095607b3ab0c73ebedf10e949748067ba3e70053ba0e
3
+ size 6378
wandb/run-20220810_111559-290849gb/run-290849gb.wandb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f22acfd4deb69c3ccc7826aa9a440ce8f7296ac3a483b3ed16e5eeab29a79ac
3
+ size 757033
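The wandb log files above are committed as Git LFS pointers rather than raw contents: each entry is just a spec version line, a sha256 oid and a byte size. A minimal sketch of reading such a pointer back, with a hypothetical helper name:

def read_lfs_pointer(path):
    # Git LFS pointer files are plain "key value" lines: version, oid (sha256:...), size
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

# e.g. read_lfs_pointer("wandb/run-20220810_111559-290849gb/run-290849gb.wandb")
# -> {"version": "https://git-lfs.github.com/spec/v1", "oid": "sha256:6f22...", "size": "757033"}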
wandb/run-20220810_145446-1k92sv35/files/code/run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1632 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
18
+ """
19
+ # You can also adapt this script for your own CTC speech recognition task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import unreplicate, pad_shard_unpad
44
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.utils import check_min_version
59
+ from transformers.utils.versions import require_version
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
63
+ check_min_version("4.17.0.dev0")
64
+
65
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ @flax.struct.dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ feature_extractor_name: Optional[str] = field(
86
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
91
+ )
92
+ use_fast_tokenizer: bool = field(
93
+ default=True,
94
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
95
+ )
96
+ model_revision: str = field(
97
+ default="main",
98
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
99
+ )
100
+ use_auth_token: bool = field(
101
+ default=False,
102
+ metadata={
103
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
104
+ "with private models)."
105
+ },
106
+ )
107
+ freeze_feature_encoder: bool = field(
108
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
109
+ )
110
+ attention_dropout: float = field(
111
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
112
+ )
113
+ activation_dropout: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
117
+ },
118
+ )
119
+ hidden_dropout: float = field(
120
+ default=0.1,
121
+ metadata={
122
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
123
+ },
124
+ )
125
+ feat_proj_dropout: float = field(
126
+ default=0.0,
127
+ metadata={
128
+ "help": "The feat proj dropout probability for feature encoder representations."
129
+ },
130
+ )
131
+ final_dropout: float = field(
132
+ default=0.0,
133
+ metadata={"help": "The dropout probability for the final projection layer."},
134
+ )
135
+ mask_time_prob: float = field(
136
+ default=0.1,
137
+ metadata={
138
+ "help": "The spec aug dropout probability for feature encoder representations."
139
+ },
140
+ )
141
+ mask_time_length: int = field(
142
+ default=10,
143
+ metadata={"help": "Length of vector span to mask along the time axis."},
144
+ )
145
+ mask_feature_prob: float = field(
146
+ default=0.0,
147
+ metadata={
148
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
149
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the feature axis."
150
+ },
151
+ )
152
+ mask_feature_length: int = field(
153
+ default=10,
154
+ metadata={"help": "Length of vector span to mask along the feature axis."},
155
+ )
156
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
157
+ ctc_loss_reduction: Optional[str] = field(
158
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
159
+ )
160
+ ctc_zero_infinity: Optional[bool] = field(
161
+ default=False, metadata={"help": "If True, will try to avoid the CTC loss going to infinity."}
162
+ )
163
+
164
+
165
+ @flax.struct.dataclass
166
+ class DataTrainingArguments:
167
+ """
168
+ Arguments pertaining to what data we are going to input our model for training and eval.
169
+ """
170
+
171
+ dataset_name: str = field(
172
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
173
+ )
174
+ dataset_config_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
176
+ )
177
+ text_column: Optional[str] = field(
178
+ default=None,
179
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
180
+ )
181
+ dataset_cache_dir: Optional[str] = field(
182
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
183
+ )
184
+ overwrite_cache: bool = field(
185
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
186
+ )
187
+ preprocessing_num_workers: Optional[int] = field(
188
+ default=None,
189
+ metadata={"help": "The number of processes to use for the preprocessing."},
190
+ )
191
+ max_train_samples: Optional[int] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ },
197
+ )
198
+ max_eval_samples: Optional[int] = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
202
+ "value if set."
203
+ },
204
+ )
205
+ max_test_samples: Optional[int] = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
209
+ "value if set."
210
+ },
211
+ )
212
+ audio_column_name: str = field(
213
+ default="audio",
214
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
215
+ )
216
+ text_column_name: str = field(
217
+ default="text",
218
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
219
+ )
220
+ max_duration_in_seconds: float = field(
221
+ default=20.0,
222
+ metadata={
223
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to `max_duration_in_seconds`"
224
+ },
225
+ )
226
+ min_duration_in_seconds: float = field(
227
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
228
+ )
229
+ max_label_length: Optional[int] = field(
230
+ default=512,
231
+ metadata={
232
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
233
+ "than this will be filtered."
234
+ },
235
+ )
236
+ min_label_length: Optional[int] = field(
237
+ default=2,
238
+ metadata={
239
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
240
+ "than this will be filtered."
241
+ },
242
+ )
243
+ pad_input_to_multiple_of: Optional[int] = field(
244
+ default=32000,
245
+ metadata={
246
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
247
+ "This is important to avoid triggering recompilations on TPU."
248
+ },
249
+ )
250
+ pad_target_to_multiple_of: Optional[int] = field(
251
+ default=None,
252
+ metadata={
253
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
254
+ "This is important to avoid triggering recompilations on TPU."
255
+ },
256
+ )
257
+ preprocessing_only: bool = field(
258
+ default=False,
259
+ metadata={
260
+ "help": "Whether to only do data preprocessing and skip training. "
261
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
262
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
263
+ "so that the cached datasets can consequently be loaded in distributed training"
264
+ },
265
+ )
266
+ train_split_name: str = field(
267
+ default="train",
268
+ metadata={
269
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
270
+ },
271
+ )
272
+ eval_split_name: str = field(
273
+ default="validation",
274
+ metadata={
275
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
276
+ },
277
+ )
278
+ do_lower_case: bool = field(
279
+ default=True,
280
+ metadata={"help": "Whether the target text should be lower cased."},
281
+ )
282
+ wandb_project: str = field(
283
+ default="flax-speech-recognition-ctc",
284
+ metadata={"help": "The name of the wandb project."},
285
+ )
286
+ wandb_name: str = field(
287
+ default=None,
288
+ metadata={"help": "The name of the wandb run."},
289
+ )
290
+ wandb_job_type: str = field(
291
+ default="CTC",
292
+ metadata={"help": "The name of the wandb job type."},
293
+ )
294
+ test_split_name: str = field(
295
+ default="test",
296
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
297
+ )
298
+ remove_punctuation: bool = field(
299
+ default=False, metadata={"help": "Whether or not to remove punctuation during training."}
300
+ )
301
+ skip_steps: Optional[int] = field(
302
+ default=0,
303
+ metadata={
304
+ "help": "Skip this number of steps. Useful to continue training"
305
+ },
306
+ )
307
+
308
+
309
+ # @flax.struct.dataclass
310
+ @dataclass
311
+ class FlaxTrainingArguments(TrainingArguments):
312
+ precision: str = field(
313
+ default="full",
314
+ metadata={
315
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
316
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
317
+ },
318
+ )
319
+ matmul_precision: str = field(
320
+ default="default",
321
+ metadata={
322
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
323
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
324
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
325
+ "it only changes the behaviors of calls with no such argument provided. "
326
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
327
+ },
328
+ )
329
+ multisteps: bool = field(
330
+ default=False,
331
+ metadata={
332
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
333
+ "a custom gradient accumulation implementation will be employed."
334
+ },
335
+ )
336
+
337
+
338
+ def to_fp32(t):
339
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
340
+
341
+
342
+ def to_bf16(t):
343
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
344
+
345
+
346
+ class MixedPrecisionTrainState(struct.PyTreeNode):
347
+ """Train state for use with a single Optax optimizer.
348
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
349
+
350
+ Synopsis::
351
+
352
+ state = TrainState.create(
353
+ apply_fn=model.apply,
354
+ params=variables['params'],
355
+ tx=tx)
356
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
357
+ for batch in data:
358
+ grads = grad_fn(state.params, batch)
359
+ state = state.apply_gradients(grads=grads)
360
+
361
+ Args:
362
+ step: Counter starts at 0 and is incremented by every call to
363
+ `.apply_gradients()`.
364
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
365
+ convenience to have a shorter params list for the `train_step()` function
366
+ in your training loop.
367
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
368
+ tx: An Optax gradient transformation.
369
+ opt_state: The state for `tx`.
370
+ dropout_rng: PRNG key for stochastic operations.
371
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
372
+ """
373
+
374
+ step: int
375
+ apply_fn: Callable = struct.field(pytree_node=False)
376
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
377
+ params: core.FrozenDict[str, Any]
378
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
379
+ opt_state: optax.OptState
380
+ dropout_rng: jnp.ndarray
381
+ max_grad_norm: Optional[float] = 1.0
382
+
383
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
384
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
385
+
386
+ Note that internally this function calls `.tx.update()` followed by a call
387
+ to `optax.apply_updates()` to update `params` and `opt_state`.
388
+
389
+ Args:
390
+ grads: Gradients that have the same pytree structure as `.params`.
391
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
392
+
393
+ Returns:
394
+ An updated instance of `self` with `step` incremented by one, `params`
395
+ and `opt_state` updated by applying `grads`, and additional attributes
396
+ replaced as specified by `kwargs`.
397
+ """
398
+
399
+ # clip gradients by global l2 norm
400
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
401
+ g_norm = linear_algebra.global_norm(grads)
402
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
403
+ grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
404
+
405
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
406
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
407
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
408
+
409
+ new_params = optax.apply_updates(self.params, updates)
410
+ return self.replace(
411
+ step=self.step + 1,
412
+ params=new_params,
413
+ opt_state=to_dtype(new_opt_state),
414
+ **kwargs,
415
+ )
416
+
417
+ @classmethod
418
+ def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs):
419
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
420
+ # downcast optimizer state to bf16 if mixed-precision training
421
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
422
+ return cls(
423
+ step=step,
424
+ apply_fn=apply_fn,
425
+ params=params,
426
+ tx=tx,
427
+ opt_state=opt_state,
428
+ **kwargs,
429
+ )
430
+
431
+ def replicate(self):
432
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
433
+
434
+
435
+ @flax.struct.dataclass
436
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
437
+ """
438
+ Data collator that will dynamically pad the inputs received.
439
+ Args:
440
+ processor ([`Wav2Vec2Processor`])
441
+ The processor used for processing the data.
442
+ decoder_start_token_id (:obj: `int`)
443
+ The begin-of-sentence of the decoder.
444
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
445
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
446
+ among:
447
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
448
+ sequence is provided).
449
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
450
+ maximum acceptable input length for the model if that argument is not provided.
451
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
452
+ different lengths).
453
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
454
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
455
+ See above for details.
456
+ max_input_length (:obj:`float`, `optional`):
457
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
458
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
459
+ If set will pad the input sequence to a multiple of the provided value.
460
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
461
+ 7.5 (Volta).
462
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
463
+ If set will pad the target sequence to a multiple of the provided value.
464
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
465
+ 7.5 (Volta).
466
+ """
467
+
468
+ processor: Any
469
+ input_padding: Union[bool, str] = "longest"
470
+ label_padding: Union[bool, str] = "max_length"
471
+ pad_input_to_multiple_of: Optional[int] = None
472
+ pad_to_multiple_of_label: Optional[int] = None
473
+ max_input_length: Optional[float] = None
474
+ max_label_length: Optional[float] = None
475
+
476
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
477
+ # split inputs and labels since they have to be of different lengths and need
478
+ # different padding methods
479
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
480
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
481
+
482
+ # reformat list to dict and set to pytorch format
483
+ batch = self.processor.feature_extractor.pad(
484
+ input_features,
485
+ max_length=self.max_input_length,
486
+ padding=self.input_padding,
487
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
488
+ return_tensors="np",
489
+ )
490
+
491
+ labels_batch = self.processor.tokenizer.pad(
492
+ label_features,
493
+ max_length=self.max_label_length,
494
+ padding=self.label_padding,
495
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
496
+ return_tensors="np",
497
+ )
498
+
499
+ labels = labels_batch["input_ids"]
500
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
501
+ labels = labels.filled(fill_value=-100)
502
+
503
+ batch["labels"] = labels
504
+
505
+ return batch
506
+
507
+
508
+ def get_grouped_indices(
509
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
510
+ ) -> np.array:
511
+ """
512
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
513
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
514
+ lengths. To do this, the indices are:
515
+
516
+ - randomly permuted (if a JAX rng is specified)
517
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
518
+ - sorted by length in each mega-batch
519
+
520
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
521
+ maximum length placed first, so that an OOM happens sooner rather than later.
522
+ """
523
+ lengths = dataset["input_length"]
524
+
525
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
526
+ if mega_batch_mult is None:
527
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
528
+ # Just in case, for tiny datasets
529
+ if mega_batch_mult == 0:
530
+ mega_batch_mult = 1
531
+
532
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
533
+ num_samples = len(lengths)
534
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
535
+
536
+ megabatch_size = mega_batch_mult * batch_size
537
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
538
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
539
+
540
+ # The rest is to get the biggest batch first.
541
+ # Since each megabatch is sorted by descending length, the longest element is the first
542
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
543
+ max_idx = np.argmax(megabatch_maximums).item()
544
+ # Switch to put the longest batch in first position
545
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
546
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
547
+
548
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
549
+
550
+ return megabatches
551
+
552
+
553
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
554
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
555
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
556
+ num_samples = len(samples_idx)
557
+ if drop_last:
558
+ samples_to_remove = num_samples % batch_size
559
+ if samples_to_remove != 0:
560
+ samples_idx = samples_idx[:-samples_to_remove]
561
+ sections_split = num_samples // batch_size
562
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
563
+ else:
564
+ sections_split = math.ceil(num_samples / batch_size)
565
+ samples_idx = np.array_split(samples_idx, sections_split)
566
+ return samples_idx
567
+
568
+
569
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
570
+ summary_writer.scalar("train_time", train_time, step)
571
+
572
+ train_metrics = get_metrics(train_metrics)
573
+ for key, vals in train_metrics.items():
574
+ tag = f"train_{key}"
575
+ for i, val in enumerate(vals):
576
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
577
+
578
+
579
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
580
+ for metric_name, value in eval_metrics.items():
581
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
582
+
583
+ if pred_str is not None:
584
+ # write output actual predictions for debugging
585
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
586
+
587
+
588
+ def write_wandb_log(metrics, step, prefix=None):
589
+ if jax.process_index() == 0:
590
+ log_metrics = {}
591
+ for k, v in metrics.items():
592
+ if "layer" in k:
593
+ log_metrics[f"{k}/"] = v
594
+ elif prefix is not None:
595
+ log_metrics[f"{prefix}/{k}"] = v
596
+ else:
597
+ log_metrics[k] = v
598
+ wandb.log(log_metrics, step)
599
+
600
+
601
+ def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"):
602
+ if jax.process_index() == 0:
603
+ # convert str data to a wandb compatible format
604
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
605
+ # we'll log the first 50 predictions for each epoch
606
+ wandb.log(
607
+ {
608
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
609
+ columns=["label_str", "pred_str"], data=str_data[:num_log]
610
+ )
611
+ },
612
+ step,
613
+ )
614
+
615
+
616
+ def create_learning_rate_fn(
617
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
618
+ ) -> Callable[[int], jnp.array]:
619
+ """Returns a linear warmup, linear_decay learning rate function."""
620
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
621
+ decay_fn = optax.linear_schedule(
622
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
623
+ )
624
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
625
+ return schedule_fn
626
+
627
+
628
+ def ctc_loss(
629
+ logits,
630
+ logits_attention_mask,
631
+ labels,
632
+ blank_id,
633
+ loss_reduction="mean",
634
+ output_emission_dict=False,
635
+ log_epsilon=-100000.0,
636
+ ):
637
+ """Computes CTC loss.
638
+ This function performs forward computation over an FSA with `N * 2` states
639
+ where `N` is the max number of labels. The states are split into two groups:
640
+ Phi states and emission states. A phi-state accepts repetition of
641
+ phi (blank)-symbols and transits to emission state when the correct label is
642
+ observed. An emission state accepts repetition of the label and transits to
643
+ the next phi states at any time (so called epsilon-transition).
644
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
645
+ and `N` denotes the time steps in `labels`.
646
+ Args:
647
+ logits: (B, T, K)-array containing log-probabilities of each class.
648
+ logitpaddings: (B, T)-array. Padding indicators for `logits`.
649
+ labels: (B, N)-array containing reference integer labels.
650
+ labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently,
651
+ `labels` must be right-padded, i.e. each row of `labelpaddings` must be
652
+ repetition of zeroes, followed by repetition of ones.
653
+ blank_id: Id for blank token.
654
+ loss_reduction: one of "mean", "sum", "none"
655
+ - "none": no reduction is applied.
656
+ - "mean": output loss will be divided by target lengths and then the
657
+ mean over the batch is taken.
658
+ - "sum": output loss are summed over batch
659
+ output_emission_dict: whether to output additional information about the emission probs
660
+ Returns:
661
+ A pair of `(per_seq_loss, aux)`.
662
+ per_seq_loss:
663
+ (B,)-array containing loss values for each sequence in the batch.
664
+ aux: Dictionary containing interim variables used for computing losses.
665
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
666
+ phi-state corresponding to the n-th label.
667
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
668
+ emission-state corresponding to the n-th label.
669
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
670
+ corresponding to each time frame.
671
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
672
+ corresponding to each time frame.
673
+ """
674
+ # label paddings are indicated by -100
675
+ labelpaddings = labels < 0
676
+ # logit paddings are the inverse of attention_mask
677
+ logitpaddings = ~logits_attention_mask
678
+
679
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
680
+ batchsize, unused_maxinputlen, num_classes = logits.shape
681
+ batchsize_, maxlabellen = labels.shape
682
+
683
+ logprobs = jax.nn.log_softmax(logits)
684
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
685
+
686
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
687
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
688
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
689
+
690
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
691
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
692
+
693
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
694
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
695
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
696
+
697
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
698
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
699
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
700
+
701
+ def loop_body(prev, x):
702
+ prev_phi, prev_emit = prev
703
+ # emit-to-phi epsilon transition, except if the next label is repetition
704
+ prev_phi_orig = prev_phi
705
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
706
+
707
+ logprob_emit, logprob_phi, pad = x
708
+
709
+ # phi-to-emit transition
710
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
711
+ # self-loop transition
712
+ next_phi = prev_phi + logprob_phi
713
+ # emit-to-phi blank transition only when the next label is repetition
714
+ next_phi = next_phi.at[:, 1:].set(
715
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
716
+ )
717
+
718
+ pad = pad.reshape((batchsize, 1))
719
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
720
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
721
+
722
+ return (next_phi, next_emit), (next_phi, next_emit)
723
+
724
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
725
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
726
+
727
+ # last row needs to be updated with the last epsilon transition
728
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
729
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
730
+
731
+ # extract per_seq_loss
732
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
733
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
734
+
735
+ if loss_reduction == "mean":
736
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
737
+ loss = (per_seq_loss / target_lengths).mean()
738
+ elif loss_reduction == "sum":
739
+ loss = per_seq_loss.sum()
740
+ else:
741
+ loss = per_seq_loss
742
+
743
+ if not output_emission_dict:
744
+ return loss
745
+
746
+ return loss, {
747
+ "logalpha_phi": logalpha_phi,
748
+ "logalpha_emit": logalpha_emit,
749
+ "logprobs_phi": logprobs_phi,
750
+ "logprobs_emit": logprobs_emit,
751
+ }
752
+
753
+
754
+ def make_dataset(data_args, seed=42):
755
+ # Pre-processing dataset
756
+ import re
757
+
758
+ def map_nst(entry):
759
+ text = entry["text"].lower()
760
+ text = text.replace("(...vær stille under dette opptaket...)", "")
761
+ text = re.sub('[áàâ]', 'a', text)
762
+ text = re.sub('[ä]', 'æ', text)
763
+ text = re.sub('[éèëê]', 'e', text)
764
+ text = re.sub('[íìïî]', 'i', text)
765
+ text = re.sub('[óòöô]', 'o', text)
766
+ text = re.sub('[ö]', 'ø', text)
767
+ text = re.sub('[ç]', 'c', text)
768
+ text = re.sub('[úùüû]', 'u', text)
769
+ # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
770
+ text = re.sub('\s+', ' ', text)
771
+ return {"text": text}
772
+
773
+ def filter_nst(entry):
774
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
775
+ return False # Too short
776
+ if re.match("pIW|CA", entry["type"]):
777
+ return False # Spelling out words
778
+ return True
779
+
780
+ def filter_npsc(entry):
781
+ # False if there are digits in the text
782
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
783
+ return False # Too short
784
+ if re.search("\d", entry["text"]):
785
+ return False
786
+ return True
787
+
788
+ def map_npsc(entry):
789
+ batch = {"text": entry["text"].lower()}
790
+ batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
791
+ batch["text"] = re.sub('[ä]', 'æ', batch["text"])
792
+ batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
793
+ batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
794
+ batch["text"] = re.sub('[óòöô]', 'o', batch["text"])
795
+ batch["text"] = re.sub('[ö]', 'ø', batch["text"])
796
+ batch["text"] = re.sub('[ç]', 'c', batch["text"])
797
+ batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
798
+ batch["text"] = re.sub('\s', ' ', batch["text"])
799
+ batch["text"] = re.sub('<ee>', 'eee', batch["text"])
800
+ batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
801
+ batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
802
+ batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
803
+ # batch["text"] = re.sub('<inaudible>', '?', batch["text"])
804
+ if "<" in batch["text"]:
805
+ raise ValueError(batch["text"])
806
+ return batch
807
+
808
+ nst = datasets.load_dataset("NbAiLab/NST", "no-close")
809
+ npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
810
+ # TODO NST_hesitate
811
+
812
+ split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC
813
+ nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed)
814
+ nst[data_args.train_split_name] = nst_train["train"]
815
+ nst[data_args.eval_split_name] = nst_train["test"]
816
+
817
+ nst = nst.filter(filter_nst).map(
818
+ map_nst,
819
+ num_proc=data_args.preprocessing_num_workers,
820
+ desc="filtering NST",
821
+ ).shuffle(seed=seed)
822
+ npsc = npsc.filter(filter_npsc).map(
823
+ map_npsc,
824
+ num_proc=data_args.preprocessing_num_workers,
825
+ desc="filtering NPSC",
826
+ ).shuffle(seed=seed)
827
+
828
+ npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]])
829
+ nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]])
830
+
831
+ combined = {}
832
+ for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name:
833
+ probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples
834
+ probs = (probs / probs.sum()).tolist()
835
+ comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
836
+ combined[split] = comb
837
+
838
+ return datasets.DatasetDict(**combined)
839
+
840
+ def main():
841
+ # 1. Parse input arguments
842
+ # See all possible arguments in src/transformers/training_args.py
843
+ # or by passing the --help flag to this script.
844
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
845
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
846
+
847
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
848
+ # If we pass only one argument to the script and it's the path to a json file,
849
+ # let's parse it to get our arguments.
850
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
851
+ else:
852
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
853
+
854
+ # 2. Setup logging
855
+ # Make one log on every process with the configuration for debugging.
856
+ logging.basicConfig(
857
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
858
+ datefmt="%m/%d/%Y %H:%M:%S",
859
+ handlers=[logging.StreamHandler(sys.stdout)],
860
+ )
861
+ # Set the verbosity to info of the Transformers logger.
862
+ # We only want one process per machine to log things on the screen.
863
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
864
+ if jax.process_index() == 0:
865
+ datasets.utils.logging.set_verbosity_warning()
866
+ transformers.utils.logging.set_verbosity_info()
867
+ else:
868
+ datasets.utils.logging.set_verbosity_error()
869
+ transformers.utils.logging.set_verbosity_error()
870
+
871
+ # Set up wandb run
872
+ if jax.process_index() == 0:
873
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
874
+
875
+ logger.info("Training/evaluation parameters %s", training_args)
876
+
877
+ # Set the default TPU matmul precision and display the number of devices
878
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
879
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
880
+
881
+ # 4. Load dataset
882
+
883
+ set_seed(training_args.seed)
884
+ raw_datasets = make_dataset(data_args, seed=training_args.seed)
885
+
886
+ # raw_datasets = DatasetDict()
887
+
888
+ # if training_args.do_train:
889
+ # raw_datasets[data_args.train_split_name] = load_dataset(
890
+ # data_args.dataset_name,
891
+ # data_args.dataset_config_name,
892
+ # split=data_args.train_split_name,
893
+ # cache_dir=data_args.dataset_cache_dir,
894
+ # use_auth_token=True if model_args.use_auth_token else None,
895
+ # )
896
+
897
+ # if training_args.do_eval:
898
+ # raw_datasets[data_args.eval_split_name] = load_dataset(
899
+ # data_args.dataset_name,
900
+ # data_args.dataset_config_name,
901
+ # split=data_args.eval_split_name,
902
+ # cache_dir=data_args.dataset_cache_dir,
903
+ # use_auth_token=True if model_args.use_auth_token else None,
904
+ # )
905
+
906
+ # if training_args.do_predict:
907
+ # test_split = data_args.test_split_name.split("+")
908
+ # for split in test_split:
909
+ # raw_datasets[split] = load_dataset(
910
+ # data_args.dataset_name,
911
+ # data_args.dataset_config_name,
912
+ # split=split,
913
+ # cache_dir=data_args.dataset_cache_dir,
914
+ # use_auth_token=True if model_args.use_auth_token else None,
915
+ # )
916
+
917
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
918
+ raise ValueError(
919
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
920
+ "training, evaluation or prediction has to be done."
921
+ )
922
+
923
+ # if not training, there is no need to run multiple epochs
924
+ if not training_args.do_train:
925
+ training_args.num_train_epochs = 1
926
+
927
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
928
+ raise ValueError(
929
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
930
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
931
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
932
+ )
933
+
934
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
935
+ raise ValueError(
936
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
937
+ "Make sure to set `--text_column_name` to the correct text column - one of "
938
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
939
+ )
940
+
941
+ # 5. Load pretrained model, tokenizer, and feature extractor
942
+ #
943
+ # Distributed training:
944
+ # The .from_pretrained methods guarantee that only one local process can concurrently
945
+ config = Wav2Vec2Config.from_pretrained(
946
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
947
+ cache_dir=model_args.cache_dir,
948
+ revision=model_args.model_revision,
949
+ use_auth_token=True if model_args.use_auth_token else None,
950
+ )
951
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
952
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
953
+ cache_dir=model_args.cache_dir,
954
+ revision=model_args.model_revision,
955
+ use_auth_token=True if model_args.use_auth_token else None,
956
+ )
957
+ tokenizer = AutoTokenizer.from_pretrained(
958
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
959
+ cache_dir=model_args.cache_dir,
960
+ revision=model_args.model_revision,
961
+ use_auth_token=True if model_args.use_auth_token else None,
962
+ )
963
+ # update config according to training args, model args, and tokenizer attributes
964
+ config.update(
965
+ {
966
+ "feat_proj_dropout": model_args.feat_proj_dropout,
967
+ "attention_dropout": model_args.attention_dropout,
968
+ "hidden_dropout": model_args.hidden_dropout,
969
+ "final_dropout": model_args.final_dropout,
970
+ "mask_time_prob": model_args.mask_time_prob,
971
+ "mask_time_length": model_args.mask_time_length,
972
+ "mask_feature_prob": model_args.mask_feature_prob,
973
+ "mask_feature_length": model_args.mask_feature_length,
974
+ "gradient_checkpointing": training_args.gradient_checkpointing,
975
+ "layerdrop": model_args.layerdrop,
976
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
977
+ "ctc_zero_infinity": model_args.ctc_zero_infinity,
978
+ "pad_token_id": tokenizer.pad_token_id,
979
+ "vocab_size": tokenizer.vocab_size, # len(tokenizer),
980
+ "activation_dropout": model_args.activation_dropout,
981
+ }
982
+ )
983
+
984
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
985
+ raise ValueError(
986
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
987
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus,"
988
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely "
989
+ "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`."
990
+ )
991
+
992
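+ # Both "full_mixed" and "half_mixed" run the model in bfloat16; "full_mixed" additionally
+ # casts the parameters handled by `to_dtype` (and hence the optimizer state) to bfloat16
+ # via mixed_precision=True, while "half_mixed" keeps those in float32. Any other value
+ # falls back to full float32 training.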
+ if training_args.precision == "full_mixed":
993
+ dtype = jnp.bfloat16
994
+ training_args.mixed_precision = True
995
+ elif training_args.precision == "half_mixed":
996
+ dtype = jnp.bfloat16
997
+ training_args.mixed_precision = False
998
+ else:
999
+ dtype = jnp.float32
1000
+ training_args.mixed_precision = False
1001
+
1002
+ try:
1003
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1004
+ model_args.model_name_or_path,
1005
+ config=config,
1006
+ dtype=dtype,
1007
+ cache_dir=model_args.cache_dir,
1008
+ revision=model_args.model_revision,
1009
+ use_auth_token=True if model_args.use_auth_token else None,
1010
+ )
1011
+ except Exception:
1012
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1013
+ model_args.model_name_or_path,
1014
+ config=config,
1015
+ dtype=dtype,
1016
+ cache_dir=model_args.cache_dir,
1017
+ revision=model_args.model_revision,
1018
+ use_auth_token=True if model_args.use_auth_token else None,
1019
+ from_pt=True,
1020
+ )
1021
+
1022
+ # 6. Resample speech dataset ALWAYS
1023
+ raw_datasets = raw_datasets.cast_column(
1024
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
1025
+ )
1026
+
1027
+ # 7. Preprocessing the datasets.
1028
+ # We need to read the audio files as arrays and tokenize the targets.
1029
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
1030
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
1031
+ max_target_length = data_args.max_label_length
1032
+ min_target_length = data_args.min_label_length
1033
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
1034
+ audio_column_name = data_args.audio_column_name
1035
+ num_workers = data_args.preprocessing_num_workers
1036
+ text_column_name = data_args.text_column_name
1037
+ model_input_name = feature_extractor.model_input_names[0]
1038
+ do_lower_case = data_args.do_lower_case
1039
+ dataset_name = data_args.dataset_name
1040
+ chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
1041
+ chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]'
1042
+ # gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
1043
+ # gigaspeech_disfluencies = ["<other>", "<sil>"]
1044
+ # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
1045
+ # "[vocalized-noise]", "_1"]
1046
+ # swb_punctuations = ["{", "}", "[", "]-", "]"]
1047
+ # earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>", "<unk>"]
1048
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
1049
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
1050
+
1051
+ if training_args.do_train and data_args.max_train_samples is not None:
1052
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples))
1053
+
1054
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1055
+ raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples))
1056
+
1057
+ if training_args.do_predict and data_args.max_test_samples is not None:
1058
+ raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_test_samples))
1059
+
1060
+ if training_args.do_train and data_args.remove_punctuation:
1061
+
1062
+ def remove_punctuation(batch):
1063
+ batch[text_column_name] = (
1064
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
1065
+ )
+ return batch
1066
+
1067
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map(
1068
+ remove_punctuation,
1069
+ num_proc=data_args.preprocessing_num_workers,
1070
+ desc="removing punctuation from train split",
1071
+ )
1072
+
1073
+ # filter data where the targets are ignored in scoring
1074
+ def is_target_labels(input_str):
1075
+ return input_str.lower() not in ignore_segments
1076
+
1077
+ raw_datasets = raw_datasets.filter(
1078
+ is_target_labels,
1079
+ num_proc=num_workers,
1080
+ input_columns=[text_column_name],
1081
+ desc="filtering data where the targets are ignored in scoring",
1082
+ )
1083
+
1084
+ def prepare_dataset(batch):
1085
+ # process audio
1086
+ try:
1087
+ sample = batch[audio_column_name]
1088
+ except ValueError:
1089
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
1090
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1091
+ # process audio length
1092
+ batch[model_input_name] = inputs.input_values[0]
1093
+ batch["input_length"] = len(batch["input_values"])
1094
+
1095
+ # process targets
1096
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
1097
+
1098
+ # if dataset_name == "google/xtreme_s":
1099
+ # # Finally, we tokenize the processed text
1100
+ # batch["labels"] = tokenizer(input_str).input_ids
1101
+ # batch["labels_length"] = len(batch["labels"])
1102
+ # return batch
1103
+
1104
+ # # Common Voice 9
1105
+ # if input_str.startswith('"') and input_str.endswith('"'):
1106
+ # # we can remove trailing quotation marks as they do not affect the transcription
1107
+ # input_str = input_str[1:-1]
1108
+ # # normalize quotation marks
1109
+ # input_str = re.sub(r'["“”]', '"', input_str)
1110
+ # # normalize apostrophes
1111
+ # input_str = re.sub(r"[’']", "'", input_str)
1112
+ # # normalize hyphens
1113
+ # input_str = re.sub(r"[—–]", "-", input_str)
1114
+ # # replace double quotation marks with single
1115
+ # input_str = input_str.replace('""', '"')
1116
+ # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str):
1117
+ # # for CV9, we'll normalize the text to always finish with punctuation
1118
+ # if input_str[-1] not in [".", "?", "!"]:
1119
+ # input_str = input_str + "."
1120
+
1121
+ # # TEDLIUM-3
1122
+ # # delete the <unk> token from the text and replace spaced apostrophes with un-spaced
1123
+ # input_str = input_str.replace("<unk>", "").replace(" '", "'")
1124
+
1125
+ # # GigaSpeech
1126
+ # for disfluency in gigaspeech_disfluencies:
1127
+ # input_str = input_str.replace(disfluency, "")
1128
+ # # convert spelled out punctuation to symbolic form
1129
+ # for punctuation, replacement in gigaspeech_punctuation.items():
1130
+ # input_str = input_str.replace(punctuation, replacement)
1131
+ # if dataset_name == "speechcolab/gigaspeech" and len(input_str):
1132
+ # # for GS, we'll normalize the text to always finish with punctuation
1133
+ # if input_str[-1] not in [".", "?", "!"]:
1134
+ # input_str = input_str + "."
1135
+
1136
+ # # SWB
1137
+ # for disfluency in swb_disfluencies:
1138
+ # input_str = input_str.replace(disfluency, "")
1139
+ # # remove parenthesised text (test data only)
1140
+ # input_str = re.sub("[\(].*?[\)]", "", input_str)
1141
+ # for punctuation in swb_punctuations:
1142
+ # input_str = input_str.replace(punctuation, "")
1143
+ # # replace anomalous words with their correct transcriptions
1144
+ # split_str = input_str.split("/")
1145
+ # if len(split_str) > 1:
1146
+ # input_str = " ".join(
1147
+ # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1148
+
1149
+ # # Earnings 22
1150
+ # for disfluency in earnings_disfluencies:
1151
+ # input_str = input_str.replace(disfluency, "")
1152
+ # # replace mal-formatted ellipsis
1153
+ # input_str = input_str.replace("…", ".")
1154
+
1155
+ # JIWER compliance
1156
+ # remove multiple spaces
1157
+ input_str = re.sub(r"\s\s+", " ", input_str)
1158
+ # strip trailing spaces
1159
+ input_str = input_str.strip()
1160
+
1161
+ # Finally, we tokenize the processed text
1162
+ batch["labels"] = tokenizer(input_str).input_ids
1163
+ batch["labels_length"] = len(batch["labels"])
1164
+ return batch
1165
+
1166
+ vectorized_datasets = raw_datasets.map(
1167
+ prepare_dataset,
1168
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1169
+ num_proc=num_workers,
1170
+ desc="preprocess dataset",
1171
+ )
1172
+
1173
+ # filter data with inputs shorter than min_input_length or longer than max_input_length
1174
+ def is_audio_in_length_range(length):
1175
+ return min_input_length < length < max_input_length
1176
+
1177
+ vectorized_datasets = vectorized_datasets.filter(
1178
+ is_audio_in_length_range,
1179
+ num_proc=num_workers,
1180
+ input_columns=["input_length"],
1181
+ )
1182
+
1183
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1184
+ def is_labels_in_length_range(length):
1185
+ return length > min_target_length # and length < max_target_length
1186
+
1187
+ vectorized_datasets = vectorized_datasets.filter(
1188
+ is_labels_in_length_range,
1189
+ num_proc=num_workers,
1190
+ input_columns=["labels_length"],
1191
+ )
1192
+
1193
+ # for large datasets it is advised to run the preprocessing on a
1194
+ # single machine first with `args.preprocessing_only` since there will most likely
1195
+ # be a timeout when running the script in distributed mode.
1196
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1197
+ # cached dataset
1198
+ if data_args.preprocessing_only:
1199
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1200
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1201
+ return
1202
+
1203
+ # 8. Load Metrics
1204
+ wer_metric = load_metric("wer")
1205
+ cer_metric = load_metric("cer")
1206
+
1207
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1208
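+ # label positions padded with -100 (ignored by the loss) are mapped back to the pad token
+ # so the references can be batch-decoded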
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1209
+
1210
+ pred_str = tokenizer.batch_decode(pred_ids)
1211
+ # we do not want to group tokens when computing the metrics
1212
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
1213
+
1214
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
1215
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
1216
+
1217
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1218
+
1219
+ # 9. save feature extractor, tokenizer and config
1220
+ feature_extractor.save_pretrained(training_args.output_dir)
1221
+ tokenizer.save_pretrained(training_args.output_dir)
1222
+ config.save_pretrained(training_args.output_dir)
1223
+
1224
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1225
+
1226
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1227
+ processor=processor,
1228
+ input_padding="longest",
1229
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1230
+ max_label_length=data_args.max_label_length,
1231
+ )
1232
+
1233
+ # Enable tensorboard only on the master node
1234
+ has_tensorboard = is_tensorboard_available()
1235
+ if has_tensorboard and jax.process_index() == 0:
1236
+ try:
1237
+ from flax.metrics.tensorboard import SummaryWriter
1238
+
1239
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1240
+ except ImportError as ie:
1241
+ has_tensorboard = False
1242
+ logger.warning(
1243
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1244
+ )
1245
+ else:
1246
+ logger.warning(
1247
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1248
+ "Please run `pip install tensorboard` to enable."
1249
+ )
1250
+
1251
+ # 10. Handle the repository creation
1252
+ if training_args.push_to_hub:
1253
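+ # make sure raw *.wandb run files are tracked with git-lfs before pushing to the Hub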
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1254
+ git_lfs_extensions = f.read()
1255
+ if "*.wandb" not in git_lfs_extensions:
1256
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1257
+ if training_args.hub_model_id is None:
1258
+ repo_name = get_full_repo_name(
1259
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1260
+ )
1261
+ else:
1262
+ repo_name = training_args.hub_model_id
1263
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1264
+
1265
+ # 11. Initialize our training
1266
+ rng = jax.random.PRNGKey(training_args.seed)
1267
+ rng, dropout_rng = jax.random.split(rng)
1268
+
1269
+ # Store some constants
1270
+ max_steps = int(training_args.max_steps)
1271
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1272
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1273
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1274
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1275
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1276
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1277
+
1278
+ if training_args.do_train:
1279
+ num_train_samples = len(vectorized_datasets[data_args.train_split_name])
1280
+ steps_per_epoch = num_train_samples // batch_size_per_update
1281
+ if max_steps > 0:
1282
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1283
+ total_train_steps = max_steps
1284
+ else:
1285
+ num_epochs = int(training_args.num_train_epochs)
1286
+ total_train_steps = steps_per_epoch * num_epochs
1287
+
1288
+ # Create learning rate schedule
1289
1290
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1291
+ total_train_steps,
1292
+ training_args.warmup_steps,
1293
+ training_args.learning_rate,
1294
+ )
1295
+
1296
+ # We use Optax's "masking" functionality to not apply weight decay
1297
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1298
+ # mask boolean with the same structure as the parameters.
1299
+ # The mask is True for parameters that should be decayed.
1300
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1301
+ # For FlaxT5, one should correct the layer norm parameter naming
1302
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1303
+ def decay_mask_fn(params):
1304
+ flat_params = traverse_util.flatten_dict(params)
1305
+ layer_norm_params = [
1306
+ (name, "scale")
1307
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1308
+ ]
1309
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1310
+ return traverse_util.unflatten_dict(flat_mask)
1311
+
1312
+ if training_args.adafactor:
1313
+ # Create Adafactor optimizer
1314
+ optim = optax.adafactor(
1315
+ learning_rate=linear_decay_lr_schedule_fn,
1316
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1317
+ weight_decay_rate=training_args.weight_decay,
1318
+ weight_decay_mask=decay_mask_fn,
1319
+ )
1320
+ else:
1321
+ # Create AdamW optimizer
1322
+ optim = optax.adamw(
1323
+ learning_rate=linear_decay_lr_schedule_fn,
1324
+ b1=training_args.adam_beta1,
1325
+ b2=training_args.adam_beta2,
1326
+ eps=training_args.adam_epsilon,
1327
+ weight_decay=training_args.weight_decay,
1328
+ mask=decay_mask_fn,
1329
+ )
1330
+
1331
+ # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
1332
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1333
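+ # use_grad_mean=False sums (rather than averages) the accumulated gradients, matching
+ # the manual accumulation branch in train_step below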
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
1334
+ else:
1335
+ num_epochs = 0
1336
+ total_train_steps = 0
1337
+ num_train_samples = 0
1338
+ optim = None
1339
+
1340
+ # Setup train state
1341
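+ # starting the step counter at skip_steps keeps the learning-rate schedule in sync with
+ # the run being resumed (only the step is restored here, not the optimizer state)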
+ state = MixedPrecisionTrainState.create(
1342
+ step=data_args.skip_steps,
1343
+ apply_fn=model.__call__,
1344
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1345
+ params=model.params,
1346
+ tx=optim,
1347
+ to_dtype=to_dtype,
1348
+ dropout_rng=dropout_rng,
1349
+ max_grad_norm=training_args.max_grad_norm,
1350
+ )
1351
+
1352
+ # Replicate the train state on each device
1353
+ state = state.replicate()
1354
+ blank_id = model.config.pad_token_id
1355
+
1356
+ # Define gradient update step fn
1357
+ def train_step(state, batch):
1358
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1359
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1360
+
1361
+ def compute_loss(params, minibatch):
1362
+ labels = minibatch.pop("labels")
1363
+ logits = state.apply_fn(
1364
+ **minibatch,
1365
+ params=params,
1366
+ dropout_rng=dropout_rng,
1367
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1368
+ train=True,
1369
+ )[0]
1370
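+ # mask out feature frames that correspond to padding so they do not contribute to the CTC loss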
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"])
1371
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1372
+
1373
+ return loss
1374
+
1375
+ grad_fn = jax.value_and_grad(compute_loss)
1376
+
1377
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1378
+ loss, grad = grad_fn(to_dtype(state.params), batch)
1379
+
1380
+ # Custom gradient accumulation
1381
+ else:
1382
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1383
+ batch = jax.tree_util.tree_map(
1384
+ lambda x: x.reshape(
1385
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1386
+ ),
1387
+ batch,
1388
+ )
1389
+
1390
+ def accum_minibatch_step(accum_grad, minibatch):
1391
+ # compute loss, num labels and grad over minibatch and accumulate
1392
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
1393
+ return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss
1394
+
1395
+ # create an initial state for accumulating losses, num labels and gradients
1396
+ init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params))
1397
+ # loop accum minibatch step over the number of gradient accumulation steps
1398
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
1399
+
1400
+ # update state
1401
+ new_state = state.apply_gradients(
1402
+ grads=grad,
1403
+ dropout_rng=new_dropout_rng,
1404
+ to_dtype=to_dtype,
1405
+ )
1406
+
1407
+ # compute gradient norms over all layers and globally for detailed monitoring
1408
+ layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad)
1409
+ logs = {
1410
+ "layer_grad_norm": layer_grad_norm,
1411
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
1412
+ }
1413
+
1414
+ # compute parameter norms over all layers and globally for detailed monitoring
1415
+ layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params)
1416
+ logs["layer_param_norm"] = layer_param_norm
1417
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
1418
+
1419
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1420
+ metrics.update(logs)
1421
+
1422
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1423
+ # metrics = to_fp32(metrics)
1424
+
1425
+ return new_state, metrics
1426
+
1427
+ # Define eval fn
1428
+ def eval_step(params, batch):
1429
+ labels = batch.pop("labels")
1430
+ logits = model(**batch, params=params, train=False)[0]
1431
+
1432
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
1433
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1434
+
1435
+ pred_ids = jnp.argmax(logits, axis=-1)
1436
+
1437
+ # summarize metrics
1438
+ metrics = {"loss": loss}
1439
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1440
+ # metrics = to_fp32(metrics)
1441
+ return metrics, pred_ids
1442
+
1443
+ # Create parallel version of the train and eval step
1444
+ if training_args.do_train:
1445
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1446
+
1447
+ if training_args.do_eval:
1448
+ p_eval_step = jax.pmap(eval_step, "batch")
1449
+
1450
+ def run_evaluation(step):
1451
+ if training_args.do_eval:
1452
+ # ======================== Evaluating ==============================
1453
+ eval_metrics = []
1454
+ eval_preds = []
1455
+ eval_labels = []
1456
+
1457
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1458
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size)
1459
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1460
+
1461
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1462
+ samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx]
1463
+ batch = data_collator(samples)
1464
+ labels = batch["labels"]
1465
+
1466
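+ # pad_shard_unpad pads the last (possibly smaller) batch up to a full per-device batch,
+ # shards it across devices for the pmapped eval step, and un-pads the returned predictions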
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1467
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1468
+ eval_metrics.append(metrics)
1469
+
1470
+ eval_labels.extend(labels)
1471
+
1472
+ # normalize eval metrics
1473
+ eval_metrics = get_metrics(eval_metrics)
1474
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1475
+ eval_metrics = to_fp32(eval_metrics)
1476
+
1477
+ # always run compute metrics
1478
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1479
+ eval_metrics.update(error_rate_metric)
1480
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1481
+
1482
+ # Print metrics and update progress bar
1483
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1484
+ epochs.write(desc)
1485
+ epochs.desc = desc
1486
+
1487
+ # Save metrics
1488
+ write_wandb_log(eval_metrics, step, prefix="eval")
1489
+ write_wandb_pred(pred_str, label_str, step)
1490
+ # if has_tensorboard and jax.process_index() == 0:
1491
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1492
+
1493
+ def save_checkpoint(step):
1494
+ # save and push checkpoint to the hub
1495
+ if jax.process_index() == 0:
1496
+ params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
1497
+ model.save_pretrained(training_args.output_dir, params=params)
1498
+ tokenizer.save_pretrained(training_args.output_dir)
1499
+ if training_args.push_to_hub:
1500
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1501
+
1502
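+ # whole epochs that lie entirely before skip_steps are skipped outright; any remaining
+ # already-seen steps are then skipped one by one inside the training loop below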
+ skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update)
1503
+ logger.info("***** Running training *****")
1504
+ logger.info(f" Num examples = {num_train_samples}")
1505
+ logger.info(f" Num Epochs = {num_epochs}")
1506
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1507
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1508
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1509
+ logger.info(f" Total optimization steps = {total_train_steps}")
1510
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
1511
+ logger.info(f" Use scan: {config.use_scan}")
1512
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
1513
+ logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)")
1514
+
1515
+ train_time = cur_step = 0
1516
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1517
+ for epoch in epochs:
1518
+ if training_args.do_train:
1519
+ # ======================== Training ================================
1520
+ train_start = time.time()
1521
+
1522
+ if epoch < skip_epochs:
1523
+ logger.info(f"Skipping epoch {epoch + 1}")
1524
+ # Create sampling rng
1525
+ rng, input_rng = jax.random.split(rng)
1526
+ continue
1527
+
1528
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1529
+ train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
1530
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1531
+
1532
+ if data_args.skip_steps > cur_step:
1533
+ logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...")
1534
+ # Gather the indices for creating the batch and do a training step
1535
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1536
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1537
+ if cur_step <= data_args.skip_steps:
1538
+ continue
1539
+
1540
+ samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx]
1541
+ batch = data_collator(samples)
1542
+ batch = shard(batch.data)
1543
+ try:
1544
+ state, train_metric = p_train_step(state, batch)
1545
+ except TypeError as e:
1546
+ logger.warning("Encountered following error: \n", e)
1547
+
1548
+
1549
+ if cur_step % training_args.logging_steps == 0:
1550
+ # Save metrics
1551
+ train_metric = unreplicate(train_metric)
1552
+ train_time += time.time() - train_start
1553
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1554
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name)
1555
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1556
+ # if has_tensorboard and jax.process_index() == 0:
1557
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1558
+
1559
+ epochs.write(
1560
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1561
+ )
1562
+ p_train_step.clear_cache()
1563
+
1564
+ if cur_step % total_train_steps == 0:
1565
+ break
1566
+
1567
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1568
+ run_evaluation(cur_step)
1569
+
1570
+ if cur_step % training_args.save_steps == 0:
1571
+ save_checkpoint(cur_step)
1572
+
1573
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1574
+ # run evaluation at the end of the epoch if eval steps are not specified
1575
+ run_evaluation(cur_step)
1576
+ save_checkpoint(cur_step)
1577
+
1578
+ if training_args.do_train:
1579
+ save_checkpoint(cur_step)
1580
+
1581
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1582
+
1583
+ if training_args.do_eval:
1584
+ run_evaluation(cur_step)
1585
+
1586
+ # TODO: collapse 'do_predict' into the run_evaluation function
1587
+ if training_args.do_predict:
1588
+ for split in [data_args.test_split_name]:
1589
+ # ======================== Evaluating ==============================
1590
+ eval_metrics = []
1591
+ eval_preds = []
1592
+ eval_labels = []
1593
+
1594
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
1595
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1596
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1597
+
1598
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
1599
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1600
+ batch = data_collator(samples)
1601
+ labels = batch["labels"]
1602
+
1603
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1604
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1605
+ eval_metrics.append(metrics)
1606
+
1607
+ eval_labels.extend(labels)
1608
+
1609
+ # normalize eval metrics
1610
+ eval_metrics = get_metrics(eval_metrics)
1611
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1612
+ eval_metrics = to_fp32(eval_metrics)
1613
+
1614
+ # always run compute metrics
1615
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1616
+ eval_metrics.update(error_rate_metric)
1617
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1618
+
1619
+ # Print metrics and update progress bar
1620
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1621
+ epochs.write(desc)
1622
+ epochs.desc = desc
1623
+
1624
+ # Save metrics
1625
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
1626
+ write_wandb_pred(pred_str, label_str, cur_step, prefix=split)
1627
+ # if has_tensorboard and jax.process_index() == 0:
1628
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
1629
+
1630
+
1631
+ if __name__ == "__main__":
1632
+ main()
wandb/run-20220810_145446-1k92sv35/files/config.yaml ADDED
@@ -0,0 +1,33 @@
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.9
7
+ code_path: code/run_flax_speech_recognition_ctc.py
8
+ framework: huggingface
9
+ huggingface_version: 4.21.0
10
+ is_jupyter_run: false
11
+ is_kaggle_kernel: false
12
+ python_version: 3.8.10
13
+ start_time: 1660143286
14
+ t:
15
+ 1:
16
+ - 1
17
+ - 2
18
+ - 3
19
+ - 11
20
+ - 12
21
+ 2:
22
+ - 1
23
+ - 2
24
+ - 3
25
+ - 11
26
+ - 12
27
+ 3:
28
+ - 13
29
+ 4: 3.8.10
30
+ 5: 0.12.9
31
+ 6: 4.21.0
32
+ 8:
33
+ - 5
wandb/run-20220810_145446-1k92sv35/files/diff.patch ADDED
@@ -0,0 +1,132 @@
1
+ diff --git a/run.recover.sh b/run.recover.sh
2
+ index 77ad3fd..6891af1 100755
3
+ --- a/run.recover.sh
4
+ +++ b/run.recover.sh
5
+ @@ -10,10 +10,9 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
6
+ --num_train_epochs="40" \
7
+ --per_device_train_batch_size="2" \
8
+ --per_device_eval_batch_size="2" \
9
+ - --gradient_accumulation_steps="1" \
10
+ - --precision="full_mixed" \
11
+ + --gradient_accumulation_steps="2" \
12
+ + --precision="half_mixed" \
13
+ --matmul_precision="bfloat16" \
14
+ - --multisteps \
15
+ --learning_rate="6.394633237505332e-05" \
16
+ --skip_steps="275000" \
17
+ --warmup_steps="2000" \
18
+ diff --git a/run_flax_speech_recognition_ctc.py b/run_flax_speech_recognition_ctc.py
19
+ index a330879..4d0f5fc 100644
20
+ --- a/run_flax_speech_recognition_ctc.py
21
+ +++ b/run_flax_speech_recognition_ctc.py
22
+ @@ -415,12 +415,12 @@ class MixedPrecisionTrainState(struct.PyTreeNode):
23
+ )
24
+
25
+ @classmethod
26
+ - def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
27
+ + def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs):
28
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
29
+ # downcast optimizer state to bf16 if mixed-precision training
30
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
31
+ return cls(
32
+ - step=0,
33
+ + step=step,
34
+ apply_fn=apply_fn,
35
+ params=params,
36
+ tx=tx,
37
+ @@ -1339,6 +1339,7 @@ def main():
38
+
39
+ # Setup train state
40
+ state = MixedPrecisionTrainState.create(
41
+ + step=data_args.skip_steps,
42
+ apply_fn=model.__call__,
43
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
44
+ params=model.params,
45
+ @@ -1520,11 +1521,10 @@ def main():
46
+
47
+ if epoch < skip_epochs:
48
+ logger.info(f"Skipping epoch {epoch + 1}")
49
+ + # Create sampling rng
50
+ + rng, input_rng = jax.random.split(rng)
51
+ continue
52
+
53
+ - # Create sampling rng
54
+ - rng, input_rng = jax.random.split(rng)
55
+ -
56
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
57
+ train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
58
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
59
+ @@ -1559,6 +1559,7 @@ def main():
60
+ epochs.write(
61
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
62
+ )
63
+ + p_train_step.clear_cache()
64
+
65
+ if cur_step % total_train_steps == 0:
66
+ break
67
+ diff --git a/special_tokens_map.json b/special_tokens_map.json
68
+ index 218961f..cc1961e 100644
69
+ --- a/special_tokens_map.json
70
+ +++ b/special_tokens_map.json
71
+ @@ -399,6 +399,34 @@
72
+ "rstrip": false,
73
+ "single_word": false
74
+ },
75
+ + {
76
+ + "content": "</s>",
77
+ + "lstrip": false,
78
+ + "normalized": true,
79
+ + "rstrip": false,
80
+ + "single_word": false
81
+ + },
82
+ + {
83
+ + "content": "<s>",
84
+ + "lstrip": false,
85
+ + "normalized": true,
86
+ + "rstrip": false,
87
+ + "single_word": false
88
+ + },
89
+ + {
90
+ + "content": "</s>",
91
+ + "lstrip": false,
92
+ + "normalized": true,
93
+ + "rstrip": false,
94
+ + "single_word": false
95
+ + },
96
+ + {
97
+ + "content": "<s>",
98
+ + "lstrip": false,
99
+ + "normalized": true,
100
+ + "rstrip": false,
101
+ + "single_word": false
102
+ + },
103
+ {
104
+ "content": "</s>",
105
+ "lstrip": false,
106
+ diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log
107
+ index 23926ef..aef858a 120000
108
+ --- a/wandb/debug-internal.log
109
+ +++ b/wandb/debug-internal.log
110
+ @@ -1 +1 @@
111
+ -run-20220805_230151-2y71vcu4/logs/debug-internal.log
112
+
113
+ +run-20220810_145446-1k92sv35/logs/debug-internal.log
114
+
115
+ diff --git a/wandb/debug.log b/wandb/debug.log
116
+ index 279853d..0d5686d 120000
117
+ --- a/wandb/debug.log
118
+ +++ b/wandb/debug.log
119
+ @@ -1 +1 @@
120
+ -run-20220805_230151-2y71vcu4/logs/debug.log
121
+
122
+ +run-20220810_145446-1k92sv35/logs/debug.log
123
+
124
+ diff --git a/wandb/latest-run b/wandb/latest-run
125
+ index f069a7a..3128ad6 120000
126
+ --- a/wandb/latest-run
127
+ +++ b/wandb/latest-run
128
+ @@ -1 +1 @@
129
+ -run-20220805_230151-2y71vcu4
130
+
131
+ +run-20220810_145446-1k92sv35
132
+
wandb/run-20220810_145446-1k92sv35/files/output.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c135f3ac8ed3bef82639474fe6693aa305cdd702fc964877622ffc3ae9ce5ce9
3
+ size 224313
wandb/run-20220810_145446-1k92sv35/files/requirements.txt ADDED
@@ -0,0 +1,158 @@
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ appdirs==1.4.4
5
+ astunparse==1.6.3
6
+ async-timeout==4.0.2
7
+ attrs==21.4.0
8
+ audioread==2.1.9
9
+ backcall==0.2.0
10
+ cachetools==4.2.4
11
+ certifi==2021.10.8
12
+ cffi==1.15.1
13
+ charset-normalizer==2.0.10
14
+ chex==0.1.3
15
+ click==8.0.3
16
+ cloud-tpu-client==0.10
17
+ cloud-tpu-profiler==2.4.0
18
+ clu==0.0.6
19
+ colorama==0.4.5
20
+ commonmark==0.9.1
21
+ configparser==5.2.0
22
+ contextlib2==21.6.0
23
+ cycler==0.11.0
24
+ datasets==2.4.0
25
+ decorator==5.1.0
26
+ dill==0.3.4
27
+ dm-tree==0.1.6
28
+ docker-pycreds==0.4.0
29
+ etils==0.6.0
30
+ exceptiongroup==1.0.0rc8
31
+ filelock==3.4.2
32
+ flatbuffers==2.0
33
+ flax==0.5.3
34
+ fonttools==4.28.5
35
+ frozenlist==1.2.0
36
+ fsspec==2021.11.1
37
+ future==0.18.2
38
+ gast==0.4.0
39
+ gitdb==4.0.9
40
+ gitpython==3.1.26
41
+ google-api-core==1.31.5
42
+ google-api-python-client==1.8.0
43
+ google-auth-httplib2==0.1.0
44
+ google-auth-oauthlib==0.4.6
45
+ google-auth==2.3.3
46
+ google-pasta==0.2.0
47
+ googleapis-common-protos==1.54.0
48
+ grpcio==1.43.0
49
+ h5py==3.6.0
50
+ httplib2==0.20.2
51
+ huggingface-hub==0.2.1
52
+ hypothesis==6.53.0
53
+ idna==3.3
54
+ importlib-metadata==4.10.0
55
+ importlib-resources==5.4.0
56
+ ipython==7.31.0
57
+ jax==0.3.15
58
+ jaxlib==0.3.15
59
+ jedi==0.18.1
60
+ jiwer==2.3.0
61
+ joblib==1.1.0
62
+ keras-preprocessing==1.1.2
63
+ keras==2.7.0
64
+ kiwisolver==1.3.2
65
+ libclang==12.0.0
66
+ librosa==0.9.2
67
+ libtpu-nightly==0.1.dev20220722
68
+ llvmlite==0.39.0
69
+ markdown==3.3.6
70
+ matplotlib-inline==0.1.3
71
+ matplotlib==3.5.1
72
+ ml-collections==0.1.0
73
+ msgpack==1.0.3
74
+ multidict==5.2.0
75
+ multiprocess==0.70.12.2
76
+ numba==0.56.0
77
+ numpy==1.22.0
78
+ oauth2client==4.1.3
79
+ oauthlib==3.1.1
80
+ opt-einsum==3.3.0
81
+ optax==0.1.3
82
+ packaging==21.3
83
+ pandas==1.3.5
84
+ parso==0.8.3
85
+ pathtools==0.1.2
86
+ pexpect==4.8.0
87
+ pickleshare==0.7.5
88
+ pillow==9.0.0
89
+ pip==22.2.2
90
+ pkg-resources==0.0.0
91
+ pooch==1.6.0
92
+ promise==2.3
93
+ prompt-toolkit==3.0.24
94
+ protobuf==3.19.1
95
+ psutil==5.9.0
96
+ ptyprocess==0.7.0
97
+ pyarrow==6.0.1
98
+ pyasn1-modules==0.2.8
99
+ pyasn1==0.4.8
100
+ pycparser==2.21
101
+ pyctcdecode==0.4.0
102
+ pygments==2.11.1
103
+ pygtrie==2.5.0
104
+ pyparsing==3.0.6
105
+ python-dateutil==2.8.2
106
+ python-levenshtein==0.12.2
107
+ pytz==2021.3
108
+ pyyaml==6.0
109
+ regex==2021.11.10
110
+ requests-oauthlib==1.3.0
111
+ requests==2.27.0
112
+ resampy==0.3.1
113
+ responses==0.18.0
114
+ rich==11.2.0
115
+ rsa==4.8
116
+ sacremoses==0.0.46
117
+ scikit-learn==1.1.1
118
+ scipy==1.7.3
119
+ sentry-sdk==1.5.2
120
+ setuptools==44.0.0
121
+ shortuuid==1.0.8
122
+ six==1.16.0
123
+ smmap==5.0.0
124
+ sortedcontainers==2.4.0
125
+ soundfile==0.10.3.post1
126
+ sox==1.4.1
127
+ subprocess32==3.5.4
128
+ tensorboard-data-server==0.6.1
129
+ tensorboard-plugin-wit==1.8.0
130
+ tensorboard==2.7.0
131
+ tensorflow-cpu==2.7.0
132
+ tensorflow-datasets==4.4.0
133
+ tensorflow-estimator==2.7.0
134
+ tensorflow-io-gcs-filesystem==0.23.1
135
+ tensorflow-metadata==1.5.0
136
+ tensorflow==2.7.0
137
+ tensorstore==0.1.21
138
+ termcolor==1.1.0
139
+ threadpoolctl==3.1.0
140
+ tokenizers==0.11.2
141
+ toolz==0.11.2
142
+ torch==1.12.0
143
+ torchaudio==0.12.0+cpu
144
+ tqdm==4.62.3
145
+ traitlets==5.1.1
146
+ transformers==4.21.0
147
+ typing-extensions==4.3.0
148
+ uritemplate==3.0.1
149
+ urllib3==1.26.7
150
+ wandb==0.12.9
151
+ wcwidth==0.2.5
152
+ werkzeug==2.0.2
153
+ wheel==0.37.1
154
+ wrapt==1.13.3
155
+ xxhash==2.0.2
156
+ yarl==1.7.2
157
+ yaspin==2.1.0
158
+ zipp==3.7.0
wandb/run-20220810_145446-1k92sv35/files/wandb-metadata.json ADDED
@@ -0,0 +1,69 @@
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2022-08-10T14:54:50.575340",
5
+ "startedAt": "2022-08-10T14:54:46.729335",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--model_name_or_path=./",
11
+ "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu",
12
+ "--tokenizer_name=./",
13
+ "--output_dir=./",
14
+ "--overwrite_output_dir",
15
+ "--num_train_epochs=40",
16
+ "--per_device_train_batch_size=2",
17
+ "--per_device_eval_batch_size=2",
18
+ "--gradient_accumulation_steps=2",
19
+ "--precision=half_mixed",
20
+ "--matmul_precision=bfloat16",
21
+ "--learning_rate=6.394633237505332e-05",
22
+ "--skip_steps=275000",
23
+ "--warmup_steps=2000",
24
+ "--length_column_name=input_length",
25
+ "--evaluation_strategy=steps",
26
+ "--text_column_name=text",
27
+ "--save_steps=5000",
28
+ "--eval_steps=5000",
29
+ "--logging_steps=100",
30
+ "--layerdrop=0.041",
31
+ "--attention_dropout=0.094",
32
+ "--activation_dropout=0.055",
33
+ "--hidden_dropout=0.047",
34
+ "--save_total_limit=5",
35
+ "--freeze_feature_encoder",
36
+ "--feat_proj_dropout=0.04",
37
+ "--mask_time_prob=0.082",
38
+ "--mask_time_length=10",
39
+ "--mask_feature_prob=0.25",
40
+ "--mask_feature_length=64",
41
+ "--gradient_checkpointing",
42
+ "--min_duration_in_seconds=0.5",
43
+ "--max_duration_in_seconds=30.0",
44
+ "--use_auth_token",
45
+ "--seed=42",
46
+ "--group_by_length",
47
+ "--do_train",
48
+ "--do_eval",
49
+ "--push_to_hub",
50
+ "--preprocessing_num_workers=32",
51
+ "--ctc_zero_infinity",
52
+ "--do_lower_case",
53
+ "--wandb_project=wav2vec2",
54
+ "--wandb_name=wav2vec2-1b-npsc-nst-tpu (cont.)",
55
+ "--remove_punctuation"
56
+ ],
57
+ "state": "running",
58
+ "program": "run_flax_speech_recognition_ctc.py",
59
+ "codePath": "run_flax_speech_recognition_ctc.py",
60
+ "git": {
61
+ "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu",
62
+ "commit": "f624ac4bfedfbf56891676d7c5f2e37b4c8e0745"
63
+ },
64
+ "email": "versae@gmail.com",
65
+ "root": "/data/wav2vec2-1b-npsc-nst-tpu",
66
+ "host": "t1v-n-eedfb410-w-0",
67
+ "username": "javierr",
68
+ "executable": "/data/flax/bin/python"
69
+ }
wandb/run-20220810_145446-1k92sv35/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
1
+ {"_wandb": {"runtime": 727}}
wandb/run-20220810_145446-1k92sv35/logs/debug-internal.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d064b2b531d3a5823be66688cf52c5cc45e8f453efd03f83df0148c4827f85db
3
+ size 43560
wandb/run-20220810_145446-1k92sv35/logs/debug.log ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f46d292d19f5c77a9812e097b2c3160a150959386c7d3ccc7915aca3eb061632
3
+ size 6071
wandb/run-20220810_145446-1k92sv35/run-1k92sv35.wandb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b8052ef4aa82bb8b1afd0f2976795c9951b74f6c9bddf669d66afdd3bafdb85
3
+ size 238991
wandb/run-20220810_151736-2jo5la5b/files/code/run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1632 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import unreplicate, pad_shard_unpad
44
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.utils import check_min_version
59
+ from transformers.utils.versions import require_version
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ check_min_version("4.17.0.dev0")
64
+
65
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ @flax.struct.dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ feature_extractor_name: Optional[str] = field(
86
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
91
+ )
92
+ use_fast_tokenizer: bool = field(
93
+ default=True,
94
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
95
+ )
96
+ model_revision: str = field(
97
+ default="main",
98
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
99
+ )
100
+ use_auth_token: bool = field(
101
+ default=False,
102
+ metadata={
103
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
104
+ "with private models)."
105
+ },
106
+ )
107
+ freeze_feature_encoder: bool = field(
108
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
109
+ )
110
+ attention_dropout: float = field(
111
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
112
+ )
113
+ activation_dropout: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
117
+ },
118
+ )
119
+ hidden_dropout: float = field(
120
+ default=0.1,
121
+ metadata={
122
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
123
+ },
124
+ )
125
+ feat_proj_dropout: float = field(
126
+ default=0.0,
127
+ metadata={
128
+ "help": "The feat proj dropout probability for feature encoder representations."
129
+ },
130
+ )
131
+ final_dropout: float = field(
132
+ default=0.0,
133
+ metadata={"help": "The dropout probability for the final projection layer."},
134
+ )
135
+ mask_time_prob: float = field(
136
+ default=0.1,
137
+ metadata={
138
+ "help": "The spec aug dropout probability for feature encoder representations."
139
+ },
140
+ )
141
+ mask_time_length: int = field(
142
+ default=10,
143
+ metadata={"help": "Length of vector span to mask along the time axis."},
144
+ )
145
+ mask_feature_prob: float = field(
146
+ default=0.0,
147
+ metadata={
148
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
149
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
150
+ },
151
+ )
152
+ mask_feature_length: int = field(
153
+ default=10,
154
+ metadata={"help": "Length of vector span to mask along the feature axis."},
155
+ )
156
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
157
+ ctc_loss_reduction: Optional[str] = field(
158
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
159
+ )
160
+ ctc_zero_infinity: Optional[bool] = field(
161
+ default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."}
162
+ )
163
+
164
+
165
+ @flax.struct.dataclass
166
+ class DataTrainingArguments:
167
+ """
168
+ Arguments pertaining to what data we are going to input our model for training and eval.
169
+ """
170
+
171
+ dataset_name: str = field(
172
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
173
+ )
174
+ dataset_config_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
176
+ )
177
+ text_column: Optional[str] = field(
178
+ default=None,
179
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
180
+ )
181
+ dataset_cache_dir: Optional[str] = field(
182
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
183
+ )
184
+ overwrite_cache: bool = field(
185
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
186
+ )
187
+ preprocessing_num_workers: Optional[int] = field(
188
+ default=None,
189
+ metadata={"help": "The number of processes to use for the preprocessing."},
190
+ )
191
+ max_train_samples: Optional[int] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ },
197
+ )
198
+ max_eval_samples: Optional[int] = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
202
+ "value if set."
203
+ },
204
+ )
205
+ max_test_samples: Optional[int] = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
209
+ "value if set."
210
+ },
211
+ )
212
+ audio_column_name: str = field(
213
+ default="audio",
214
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
215
+ )
216
+ text_column_name: str = field(
217
+ default="text",
218
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
219
+ )
220
+ max_duration_in_seconds: float = field(
221
+ default=20.0,
222
+ metadata={
223
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
224
+ },
225
+ )
226
+ min_duration_in_seconds: float = field(
227
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
228
+ )
229
+ max_label_length: Optional[int] = field(
230
+ default=512,
231
+ metadata={
232
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
233
+ "than this will be filtered."
234
+ },
235
+ )
236
+ min_label_length: Optional[int] = field(
237
+ default=2,
238
+ metadata={
239
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
240
+ "than this will be filtered."
241
+ },
242
+ )
243
+ pad_input_to_multiple_of: Optional[int] = field(
244
+ default=32000,
245
+ metadata={
246
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
247
+ "This is important to avoid triggering recompilations on TPU."
248
+ },
249
+ )
250
+ pad_target_to_multiple_of: Optional[int] = field(
251
+ default=None,
252
+ metadata={
253
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
254
+ "This is important to avoid triggering recompilations on TPU."
255
+ },
256
+ )
257
+ preprocessing_only: bool = field(
258
+ default=False,
259
+ metadata={
260
+ "help": "Whether to only do data preprocessing and skip training. "
261
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
262
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
263
+ "so that the cached datasets can consequently be loaded in distributed training"
264
+ },
265
+ )
266
+ train_split_name: str = field(
267
+ default="train",
268
+ metadata={
269
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
270
+ },
271
+ )
272
+ eval_split_name: str = field(
273
+ default="validation",
274
+ metadata={
275
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
276
+ },
277
+ )
278
+ do_lower_case: bool = field(
279
+ default=True,
280
+ metadata={"help": "Whether the target text should be lower cased."},
281
+ )
282
+ wandb_project: str = field(
283
+ default="flax-speech-recognition-ctc",
284
+ metadata={"help": "The name of the wandb project."},
285
+ )
286
+ wandb_name: str = field(
287
+ default=None,
288
+ metadata={"help": "The name of the wandb run."},
289
+ )
290
+ wandb_job_type: str = field(
291
+ default="CTC",
292
+ metadata={"help": "The name of the wandb job type."},
293
+ )
294
+ test_split_name: str = field(
295
+ default="test",
296
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
297
+ )
298
+ remove_punctuation: bool = field(
299
+ default=False, metadata={"help": "Whether or not to remove punctuation during training."}
300
+ )
301
+ skip_steps: Optional[int] = field(
302
+ default=0,
303
+ metadata={
304
+ "help": "Skip this number of steps. Useful to continue training"
305
+ },
306
+ )
307
+
308
+
309
+ # @flax.struct.dataclass
310
+ @dataclass
311
+ class FlaxTrainingArguments(TrainingArguments):
312
+ precision: str = field(
313
+ default="full",
314
+ metadata={
315
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
316
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
317
+ },
318
+ )
319
+ matmul_precision: str = field(
320
+ default="default",
321
+ metadata={
322
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
323
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
324
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
325
+ "it only changes the behaviors of calls with no such argument provided. "
326
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
327
+ },
328
+ )
329
+ multisteps: bool = field(
330
+ default=False,
331
+ metadata={
332
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
333
+ "a custom gradient accumulation implementation will be employed."
334
+ },
335
+ )
336
+
337
+
338
+ def to_fp32(t):
339
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
340
+
341
+
342
+ def to_bf16(t):
343
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
344
+
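+ # Illustrative only (not part of the original script): to_bf16/to_fp32 cast every
+ # matching floating-point leaf of a pytree while leaving other dtypes untouched, e.g.
+ #   params = {"w": jnp.ones((2, 2), jnp.float32), "step": jnp.array(3, jnp.int32)}
+ #   to_bf16(params)           # {"w": bfloat16 array, "step": unchanged int32 scalar}
+ #   to_fp32(to_bf16(params))  # round-trips "w" back to float32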
345
+
346
+ class MixedPrecisionTrainState(struct.PyTreeNode):
347
+ """Train state for use with a single Optax optimizer.
348
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
349
+
350
+ Synopsis::
351
+
352
+ state = TrainState.create(
353
+ apply_fn=model.apply,
354
+ params=variables['params'],
355
+ tx=tx)
356
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
357
+ for batch in data:
358
+ grads = grad_fn(state.params, batch)
359
+ state = state.apply_gradients(grads=grads)
360
+
361
+ Args:
362
+ step: Counter starts at 0 and is incremented by every call to
363
+ `.apply_gradients()`.
364
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
365
+ convenience to have a shorter params list for the `train_step()` function
366
+ in your training loop.
367
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
368
+ tx: An Optax gradient transformation.
369
+ opt_state: The state for `tx`.
370
+ dropout_rng: PRNG key for stochastic operations.
371
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
372
+ """
373
+
374
+ step: int
375
+ apply_fn: Callable = struct.field(pytree_node=False)
376
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
377
+ params: core.FrozenDict[str, Any]
378
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
379
+ opt_state: optax.OptState
380
+ dropout_rng: jnp.ndarray
381
+ max_grad_norm: Optional[float] = 1.0
382
+
383
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
384
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
385
+
386
+ Note that internally this function calls `.tx.update()` followed by a call
387
+ to `optax.apply_updates()` to update `params` and `opt_state`.
388
+
389
+ Args:
390
+ grads: Gradients that have the same pytree structure as `.params`.
391
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
392
+
393
+ Returns:
394
+ An updated instance of `self` with `step` incremented by one, `params`
395
+ and `opt_state` updated by applying `grads`, and additional attributes
396
+ replaced as specified by `kwargs`.
397
+ """
398
+
399
+ # clip gradients by global l2 norm
400
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
401
+ g_norm = linear_algebra.global_norm(grads)
402
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
403
+ grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
404
+
405
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
406
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
407
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
408
+
409
+ new_params = optax.apply_updates(self.params, updates)
410
+ return self.replace(
411
+ step=self.step + 1,
412
+ params=new_params,
413
+ opt_state=to_dtype(new_opt_state),
414
+ **kwargs,
415
+ )
416
+
417
+ @classmethod
418
+ def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs):
419
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
420
+ # downcast optimizer state to bf16 if mixed-precision training
421
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
422
+ return cls(
423
+ step=step,
424
+ apply_fn=apply_fn,
425
+ params=params,
426
+ tx=tx,
427
+ opt_state=opt_state,
428
+ **kwargs,
429
+ )
430
+
431
+ def replicate(self):
432
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
433
+
434
+
435
+ @flax.struct.dataclass
436
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
437
+ """
438
+ Data collator that will dynamically pad the inputs received.
439
+ Args:
440
+ processor ([`Wav2Vec2Processor`])
441
+ The processor used for processing the data.
442
+ decoder_start_token_id (:obj: `int`)
443
+ The begin-of-sentence of the decoder.
444
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
445
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
446
+ among:
447
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
448
+ sequence is provided).
449
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
450
+ maximum acceptable input length for the model if that argument is not provided.
451
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
452
+ different lengths).
453
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
454
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
455
+ See above for details.
456
+ max_input_length (:obj:`float`, `optional`):
457
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
458
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
459
+ If set will pad the input sequence to a multiple of the provided value.
460
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
461
+ 7.5 (Volta).
462
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
463
+ If set will pad the target sequence to a multiple of the provided value.
464
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
465
+ 7.5 (Volta).
466
+ """
467
+
468
+ processor: Any
469
+ input_padding: Union[bool, str] = "longest"
470
+ label_padding: Union[bool, str] = "max_length"
471
+ pad_input_to_multiple_of: Optional[int] = None
472
+ pad_to_multiple_of_label: Optional[int] = None
473
+ max_input_length: Optional[float] = None
474
+ max_label_length: Optional[float] = None
475
+
476
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
477
+ # split inputs and labels since they have to be of different lengths and need
478
+ # different padding methods
479
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
480
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
481
+
482
+ # reformat list to dict and set to pytorch format
483
+ batch = self.processor.feature_extractor.pad(
484
+ input_features,
485
+ max_length=self.max_input_length,
486
+ padding=self.input_padding,
487
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
488
+ return_tensors="np",
489
+ )
490
+
491
+ labels_batch = self.processor.tokenizer.pad(
492
+ label_features,
493
+ max_length=self.max_label_length,
494
+ padding=self.label_padding,
495
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
496
+ return_tensors="np",
497
+ )
498
+
499
+ labels = labels_batch["input_ids"]
500
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
501
+ labels = labels.filled(fill_value=-100)
502
+
503
+ batch["labels"] = labels
504
+
505
+ return batch
506
+
507
+
508
+ def get_grouped_indices(
509
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
510
+ ) -> np.ndarray:
511
+ """
512
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
513
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
514
+ lengths. To do this, the indices are:
515
+
516
+ - randomly permuted (if a JAX rng is specified)
517
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
518
+ - sorted by length in each mega-batch
519
+
520
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
521
+ maximum length placed first, so that an OOM happens sooner rather than later.
522
+ """
523
+ lengths = dataset["input_length"]
524
+
525
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
526
+ if mega_batch_mult is None:
527
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
528
+ # Just in case, for tiny datasets
529
+ if mega_batch_mult == 0:
530
+ mega_batch_mult = 1
531
+
532
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
533
+ num_samples = len(lengths)
534
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
535
+
536
+ megabatch_size = mega_batch_mult * batch_size
537
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
538
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
539
+
540
+ # The rest is to get the biggest batch first.
541
+ # Since each megabatch is sorted by descending length, the longest element is the first
542
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
543
+ max_idx = np.argmax(megabatch_maximums).item()
544
+ # Switch to put the longest batch in first position
545
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
546
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
547
+
548
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
549
+
550
+ return megabatches
551
+
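+ # Illustrative only (hypothetical toy input, not part of the original script):
+ # with lengths [3, 9, 1, 7], batch_size=2 and no rng, mega_batch_mult falls back to 1,
+ # each mega-batch of 2 indices is sorted by descending length, and the mega-batch holding
+ # the overall longest sample is moved to the front, yielding indices [1, 0, 3, 2]
+ # (i.e. batches with lengths [9, 3] and [7, 1]), so the largest batch is seen first.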
552
+
553
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
554
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
555
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
556
+ num_samples = len(samples_idx)
557
+ if drop_last:
558
+ samples_to_remove = num_samples % batch_size
559
+ if samples_to_remove != 0:
560
+ samples_idx = samples_idx[:-samples_to_remove]
561
+ sections_split = num_samples // batch_size
562
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
563
+ else:
564
+ sections_split = math.ceil(num_samples / batch_size)
565
+ samples_idx = np.array_split(samples_idx, sections_split)
566
+ return samples_idx
567
+
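+ # Illustrative only (hypothetical toy input): for 10 sample indices and batch_size=4,
+ #   generate_batch_splits(np.arange(10), 4)                   # -> shape (2, 4); last 2 indices dropped
+ #   generate_batch_splits(np.arange(10), 4, drop_last=False)  # -> 3 arrays of sizes 4, 3, 3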
568
+
569
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
570
+ summary_writer.scalar("train_time", train_time, step)
571
+
572
+ train_metrics = get_metrics(train_metrics)
573
+ for key, vals in train_metrics.items():
574
+ tag = f"train_{key}"
575
+ for i, val in enumerate(vals):
576
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
577
+
578
+
579
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
580
+ for metric_name, value in eval_metrics.items():
581
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
582
+
583
+ if pred_str is not None:
584
+ # write output actual predictions for debugging
585
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
586
+
587
+
588
+ def write_wandb_log(metrics, step, prefix=None):
589
+ if jax.process_index() == 0:
590
+ log_metrics = {}
591
+ for k, v in metrics.items():
592
+ if "layer" in k:
593
+ log_metrics[f"{k}/"] = v
594
+ elif prefix is not None:
595
+ log_metrics[f"{prefix}/{k}"] = v
596
+ else:
597
+ log_metrics[k] = v
598
+ wandb.log(log_metrics, step)
599
+
600
+
601
+ def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"):
602
+ if jax.process_index() == 0:
603
+ # convert str data to a wandb compatible format
604
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
605
+ # we'll log the first 50 predictions for each epoch
606
+ wandb.log(
607
+ {
608
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
609
+ columns=["label_str", "pred_str"], data=str_data[:num_log]
610
+ )
611
+ },
612
+ step,
613
+ )
614
+
615
+
616
+ def create_learning_rate_fn(
617
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
618
+ ) -> Callable[[int], jnp.array]:
619
+ """Returns a linear warmup, linear_decay learning rate function."""
620
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
621
+ decay_fn = optax.linear_schedule(
622
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
623
+ )
624
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
625
+ return schedule_fn
626
+
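+ # Illustrative only (hypothetical hyper-parameters): create_learning_rate_fn(100, 10, 1e-4)
+ # ramps linearly from 0 to 1e-4 over steps 0-10, then decays linearly back to 0 by step 100,
+ # e.g. lr(0)=0.0, lr(10)=1e-4, lr(55)=5e-5, lr(100)=0.0.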
627
+
628
+ def ctc_loss(
629
+ logits,
630
+ logits_attention_mask,
631
+ labels,
632
+ blank_id,
633
+ loss_reduction="mean",
634
+ output_emission_dict=False,
635
+ log_epsilon=-100000.0,
636
+ ):
637
+ """Computes CTC loss.
638
+ This function performs forward computation over an FSA with `N * 2` states
639
+ where `N` is the max number of labels. The states are split into two groups:
640
+ Phi states and emission states. A phi-state accepts repetition of
641
+ phi (blank)-symbols and transits to emission state when the correct label is
642
+ observed. An emission state accepts repetition of the label and transits to
643
+ the next phi states at any time (so called epsilon-transition).
644
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
645
+ and `N` denotes the time steps in `labels`.
646
+ Args:
647
+ logits: (B, T, K)-array containing log-probabilities of each class.
648
+ logits_attention_mask: (B, T)-array. Attention mask for `logits`; frames with
649
+ mask 0 are treated as padding.
650
+ labels: (B, N)-array containing reference integer labels. Label padding is
651
+ indicated by -100, and `labels` must be right-padded, i.e. each row is a run
652
+ of valid labels followed by a run of -100 values.
653
+ blank_id: Id for blank token.
654
+ loss_reduction: one of "mean", "sum", "default"
655
+ - "none": no reduction is applied.
656
+ - "mean": output loss will be divided by target lengths and then the
657
+ mean over the batch is taken.
658
+ - "sum": output loss are summed over batch
659
+ output_emission_dict: whether to output additional information about the emission probs
660
+ Returns:
661
+ A pair of `(per_seq_loss, aux)`.
662
+ per_seq_loss:
663
+ (B,)-array containing loss values for each sequence in the batch.
664
+ aux: Dictionary containing interim variables used for computing losses.
665
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
666
+ phi-state corresponding to the n-th label.
667
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
668
+ emission-state corresponding to the n-th label.
669
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
670
+ corresponding to each time frame.
671
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
672
+ corresponding to each time frame.
673
+ """
674
+ # label paddings are indicated by -100
675
+ labelpaddings = labels < 0
676
+ # logit paddings are the inverse of attention_mask
677
+ logitpaddings = ~logits_attention_mask
678
+
679
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
680
+ batchsize, unused_maxinputlen, num_classes = logits.shape
681
+ batchsize_, maxlabellen = labels.shape
682
+
683
+ logprobs = jax.nn.log_softmax(logits)
684
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
685
+
686
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
687
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
688
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
689
+
690
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
691
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
692
+
693
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
694
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
695
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
696
+
697
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
698
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
699
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
700
+
701
+ def loop_body(prev, x):
702
+ prev_phi, prev_emit = prev
703
+ # emit-to-phi epsilon transition, except if the next label is repetition
704
+ prev_phi_orig = prev_phi
705
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
706
+
707
+ logprob_emit, logprob_phi, pad = x
708
+
709
+ # phi-to-emit transition
710
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
711
+ # self-loop transition
712
+ next_phi = prev_phi + logprob_phi
713
+ # emit-to-phi blank transition only when the next label is repetition
714
+ next_phi = next_phi.at[:, 1:].set(
715
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
716
+ )
717
+
718
+ pad = pad.reshape((batchsize, 1))
719
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
720
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
721
+
722
+ return (next_phi, next_emit), (next_phi, next_emit)
723
+
724
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
725
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
726
+
727
+ # last row needs to be updated with the last epsilon transition
728
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
729
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
730
+
731
+ # extract per_seq_loss
732
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
733
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
734
+
735
+ if loss_reduction == "mean":
736
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
737
+ loss = (per_seq_loss / target_lengths).mean()
738
+ elif loss_reduction == "sum":
739
+ loss = per_seq_loss.sum()
740
+ else:
741
+ loss = per_seq_loss
742
+
743
+ if not output_emission_dict:
744
+ return loss
745
+
746
+ return loss, {
747
+ "logalpha_phi": logalpha_phi,
748
+ "logalpha_emit": logalpha_emit,
749
+ "logprobs_phi": logprobs_phi,
750
+ "logprobs_emit": logprobs_emit,
751
+ }
752
+
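+ # Illustrative only (toy shapes, not part of the original script): a minimal call to the loss above
+ # with B=2 sequences, T=5 frames, K=4 classes (blank_id=0) and N=3 label slots:
+ #   logits = jnp.zeros((2, 5, 4))
+ #   attention_mask = jnp.ones((2, 5), dtype=bool)
+ #   labels = jnp.array([[1, 2, -100], [3, -100, -100]])  # -100 marks label padding
+ #   loss = ctc_loss(logits, attention_mask, labels, blank_id=0)  # scalar, "mean" reduction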
753
+
754
+ def make_dataset(data_args, seed=42):
755
+ # Pre-processing dataset
756
+ import re
757
+
758
+ def map_nst(entry):
759
+ text = entry["text"].lower()
760
+ text = text.replace("(...vær stille under dette opptaket...)", "")
761
+ text = re.sub('[áàâ]', 'a', text)
762
+ text = re.sub('[ä]', 'æ', text)
763
+ text = re.sub('[éèëê]', 'e', text)
764
+ text = re.sub('[íìïî]', 'i', text)
765
+ text = re.sub('[óòöô]', 'o', text)
766
+ text = re.sub('[ö]', 'ø', text)
767
+ text = re.sub('[ç]', 'c', text)
768
+ text = re.sub('[úùüû]', 'u', text)
769
+ # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
770
+ text = re.sub('\s+', ' ', text)
771
+ return {"text": text}
772
+
773
+ def filter_nst(entry):
774
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
775
+ return False # Too short
776
+ if re.match(entry["type"], "pIW|CA"):
777
+ return False # Spelling out words
778
+ return True
779
+
780
+ def filter_npsc(entry):
781
+ # False if there are digits in the text
782
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
783
+ return False # Too short
784
+ if re.search("\d", entry["text"]):
785
+ return False
786
+ return True
787
+
788
+ def map_npsc(entry):
789
+ batch = {"text": entry["text"].lower()}
790
+ batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
791
+ batch["text"] = re.sub('[ä]', 'æ', batch["text"])
792
+ batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
793
+ batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
794
+ batch["text"] = re.sub('[óòöô]', 'o', batch["text"])
795
+ batch["text"] = re.sub('[ö]', 'ø', batch["text"])
796
+ batch["text"] = re.sub('[ç]', 'c', batch["text"])
797
+ batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
798
+ batch["text"] = re.sub('\s', ' ', batch["text"])
799
+ batch["text"] = re.sub('<ee>', 'eee', batch["text"])
800
+ batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
801
+ batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
802
+ batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
803
+ # batch["text"] = re.sub('<inaudible>', '?', batch["text"])
804
+ if "<" in batch["text"]:
805
+ raise ValueError(batch["text"])
806
+ return batch
807
+
808
+ nst = datasets.load_dataset("NbAiLab/NST", "no-close")
809
+ npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
810
+ # TODO NST_hesitate
811
+
812
+ split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC
813
+ nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed)
814
+ nst[data_args.train_split_name] = nst_train["train"]
815
+ nst[data_args.eval_split_name] = nst_train["test"]
816
+
817
+ nst = nst.filter(filter_nst).map(
818
+ map_nst,
819
+ num_proc=data_args.preprocessing_num_workers,
820
+ desc="filtering NST",
821
+ ).shuffle(seed=seed)
822
+ npsc = npsc.filter(filter_npsc).map(
823
+ map_npsc,
824
+ num_proc=data_args.preprocessing_num_workers,
825
+ desc="filtering NPSC",
826
+ ).shuffle(seed=seed)
827
+
828
+ npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]])
829
+ nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]])
830
+
831
+ combined = {}
832
+ for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name:
833
+ probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples
834
+ probs = (probs / probs.sum()).tolist()
835
+ comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
836
+ combined[split] = comb
837
+
838
+ return datasets.DatasetDict(**combined)
839
+
840
+ def main():
841
+ # 1. Parse input arguments
842
+ # See all possible arguments in src/transformers/training_args.py
843
+ # or by passing the --help flag to this script.
844
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
845
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
846
+
847
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
848
+ # If we pass only one argument to the script and it's the path to a json file,
849
+ # let's parse it to get our arguments.
850
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
851
+ else:
852
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
853
+
854
+ # 2. Setup logging
855
+ # Make one log on every process with the configuration for debugging.
856
+ logging.basicConfig(
857
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
858
+ datefmt="%m/%d/%Y %H:%M:%S",
859
+ handlers=[logging.StreamHandler(sys.stdout)],
860
+ )
861
+ # Set the verbosity to info of the Transformers logger.
862
+ # We only want one process per machine to log things on the screen.
863
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
864
+ if jax.process_index() == 0:
865
+ datasets.utils.logging.set_verbosity_warning()
866
+ transformers.utils.logging.set_verbosity_info()
867
+ else:
868
+ datasets.utils.logging.set_verbosity_error()
869
+ transformers.utils.logging.set_verbosity_error()
870
+
871
+ # Set up wandb run
872
+ if jax.process_index() == 0:
873
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
874
+
875
+ logger.info("Training/evaluation parameters %s", training_args)
876
+
877
+ # Set the default TPU matmul precision and display the number of devices
878
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
879
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
880
+
881
+ # 4. Load dataset
882
+
883
+ set_seed(training_args.seed)
884
+ raw_datasets = make_dataset(data_args, seed=training_args.seed)
885
+
886
+ # raw_datasets = DatasetDict()
887
+
888
+ # if training_args.do_train:
889
+ # raw_datasets[data_args.train_split_name] = load_dataset(
890
+ # data_args.dataset_name,
891
+ # data_args.dataset_config_name,
892
+ # split=data_args.train_split_name,
893
+ # cache_dir=data_args.dataset_cache_dir,
894
+ # use_auth_token=True if model_args.use_auth_token else None,
895
+ # )
896
+
897
+ # if training_args.do_eval:
898
+ # raw_datasets[data_args.eval_split_name] = load_dataset(
899
+ # data_args.dataset_name,
900
+ # data_args.dataset_config_name,
901
+ # split=data_args.eval_split_name,
902
+ # cache_dir=data_args.dataset_cache_dir,
903
+ # use_auth_token=True if model_args.use_auth_token else None,
904
+ # )
905
+
906
+ # if training_args.do_predict:
907
+ # test_split = data_args.test_split_name.split("+")
908
+ # for split in test_split:
909
+ # raw_datasets[split] = load_dataset(
910
+ # data_args.dataset_name,
911
+ # data_args.dataset_config_name,
912
+ # split=split,
913
+ # cache_dir=data_args.dataset_cache_dir,
914
+ # use_auth_token=True if model_args.use_auth_token else None,
915
+ # )
916
+
917
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
918
+ raise ValueError(
919
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
920
+ "training, evaluation or prediction has to be done."
921
+ )
922
+
923
+ # if not training, there is no need to run multiple epochs
924
+ if not training_args.do_train:
925
+ training_args.num_train_epochs = 1
926
+
927
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
928
+ raise ValueError(
929
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
930
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
931
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
932
+ )
933
+
934
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
935
+ raise ValueError(
936
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
937
+ "Make sure to set `--text_column_name` to the correct text column - one of "
938
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
939
+ )
940
+
941
+ # 5. Load pretrained model, tokenizer, and feature extractor
942
+ #
943
+ # Distributed training:
944
+ # The .from_pretrained methods guarantee that only one local process can concurrently
945
+ config = Wav2Vec2Config.from_pretrained(
946
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
947
+ cache_dir=model_args.cache_dir,
948
+ revision=model_args.model_revision,
949
+ use_auth_token=True if model_args.use_auth_token else None,
950
+ )
951
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
952
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
953
+ cache_dir=model_args.cache_dir,
954
+ revision=model_args.model_revision,
955
+ use_auth_token=True if model_args.use_auth_token else None,
956
+ )
957
+ tokenizer = AutoTokenizer.from_pretrained(
958
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
959
+ cache_dir=model_args.cache_dir,
960
+ revision=model_args.model_revision,
961
+ use_auth_token=True if model_args.use_auth_token else None,
962
+ )
963
+ # update config according to training args, model args, and tokenizer attributes
964
+ config.update(
965
+ {
966
+ "feat_proj_dropout": model_args.feat_proj_dropout,
967
+ "attention_dropout": model_args.attention_dropout,
968
+ "hidden_dropout": model_args.hidden_dropout,
969
+ "final_dropout": model_args.final_dropout,
970
+ "mask_time_prob": model_args.mask_time_prob,
971
+ "mask_time_length": model_args.mask_time_length,
972
+ "mask_feature_prob": model_args.mask_feature_prob,
973
+ "mask_feature_length": model_args.mask_feature_length,
974
+ "gradient_checkpointing": training_args.gradient_checkpointing,
975
+ "layerdrop": model_args.layerdrop,
976
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
977
+ "ctc_zero_infinity": model_args.ctc_zero_infinity,
978
+ "pad_token_id": tokenizer.pad_token_id,
979
+ "vocab_size": tokenizer.vocab_size, # len(tokenizer),
980
+ "activation_dropout": model_args.activation_dropout,
981
+ }
982
+ )
983
+
984
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
985
+ raise ValueError(
986
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
987
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus,"
988
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely "
989
+ "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`."
990
+ )
991
+
992
+ if training_args.precision == "full_mixed":
993
+ dtype = jnp.bfloat16
994
+ training_args.mixed_precision = True
995
+ elif training_args.precision == "half_mixed":
996
+ dtype = jnp.bfloat16
997
+ training_args.mixed_precision = False
998
+ else:
999
+ dtype = jnp.float32
1000
+ training_args.mixed_precision = False
1001
+
1002
+ try:
1003
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1004
+ model_args.model_name_or_path,
1005
+ config=config,
1006
+ dtype=dtype,
1007
+ cache_dir=model_args.cache_dir,
1008
+ revision=model_args.model_revision,
1009
+ use_auth_token=True if model_args.use_auth_token else None,
1010
+ )
1011
+ except:
1012
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1013
+ model_args.model_name_or_path,
1014
+ config=config,
1015
+ dtype=dtype,
1016
+ cache_dir=model_args.cache_dir,
1017
+ revision=model_args.model_revision,
1018
+ use_auth_token=True if model_args.use_auth_token else None,
1019
+ from_pt=True,
1020
+ )
1021
+
1022
+ # 6. Resample speech dataset ALWAYS
1023
+ raw_datasets = raw_datasets.cast_column(
1024
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
1025
+ )
1026
+
1027
+ # 7. Preprocessing the datasets.
1028
+ # We need to read the audio files as arrays and tokenize the targets.
1029
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
1030
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
1031
+ max_target_length = data_args.max_label_length
1032
+ min_target_length = data_args.min_label_length
1033
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
1034
+ audio_column_name = data_args.audio_column_name
1035
+ num_workers = data_args.preprocessing_num_workers
1036
+ text_column_name = data_args.text_column_name
1037
+ model_input_name = feature_extractor.model_input_names[0]
1038
+ do_lower_case = data_args.do_lower_case
1039
+ dataset_name = data_args.dataset_name
1040
+ chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
1041
+ chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]'
1042
+ # gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
1043
+ # gigaspeech_disfluencies = ["<other>", "<sil>"]
1044
+ # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
1045
+ # "[vocalized-noise]", "_1"]
1046
+ # swb_punctuations = ["{", "}", "[", "]-", "]"]
1047
+ # earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>", "<unk>"]
1048
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
1049
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
1050
+
1051
+ if training_args.do_train and data_args.max_train_samples is not None:
1052
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples))
1053
+
1054
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1055
+ raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples))
1056
+
1057
+ if training_args.do_predict and data_args.max_test_samples is not None:
1058
+ raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_test_samples))
1059
+
1060
+ if training_args.do_train and data_args.remove_punctuation:
1061
+
1062
+ def remove_punctuation(batch):
1063
+ batch[text_column_name] = (
1064
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
1065
+ )
1066
+
1067
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map(
1068
+ remove_punctuation,
1069
+ num_proc=data_args.preprocessing_num_workers,
1070
+ desc="removing punctuation from train split",
1071
+ )
1072
+
1073
+ # filter data where the targets are ignored in scoring
1074
+ def is_target_labels(input_str):
1075
+ return input_str.lower() not in ignore_segments
1076
+
1077
+ raw_datasets = raw_datasets.filter(
1078
+ is_target_labels,
1079
+ num_proc=num_workers,
1080
+ input_columns=[text_column_name],
1081
+ desc="filtering data where the targets are ignored in scoring",
1082
+ )
1083
+
1084
+ def prepare_dataset(batch):
1085
+ # process audio
1086
+ try:
1087
+ sample = batch[audio_column_name]
1088
+ except ValueError:
1089
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
1090
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1091
+ # process audio length
1092
+ batch[model_input_name] = inputs.input_values[0]
1093
+ batch["input_length"] = len(batch["input_values"])
1094
+
1095
+ # process targets
1096
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
1097
+
1098
+ # if dataset_name == "google/xtreme_s":
1099
+ # # Finally, we tokenize the processed text
1100
+ # batch["labels"] = tokenizer(input_str).input_ids
1101
+ # batch["labels_length"] = len(batch["labels"])
1102
+ # return batch
1103
+
1104
+ # # Common Voice 9
1105
+ # if input_str.startswith('"') and input_str.endswith('"'):
1106
+ # # we can remove trailing quotation marks as they do not affect the transcription
1107
+ # input_str = input_str[1:-1]
1108
+ # # normalize quotation marks
1109
+ # input_str = re.sub(r'["“”]', '"', input_str)
1110
+ # # normalize apostrophes
1111
+ # input_str = re.sub(r"[’']", "'", input_str)
1112
+ # # normalize hyphens
1113
+ # input_str = re.sub(r"[—–]", "-", input_str)
1114
+ # # replace double quotation marks with single
1115
+ # input_str = input_str.replace('""', '"')
1116
+ # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str):
1117
+ # # for CV9, we'll normalize the text to always finish with punctuation
1118
+ # if input_str[-1] not in [".", "?", "!"]:
1119
+ # input_str = input_str + "."
1120
+
1121
+ # # TEDLIUM-3
1122
+ # # delete the <unk> token from the text and replace spaced apostrophes with un-spaced
1123
+ # input_str = input_str.replace("<unk>", "").replace(" '", "'")
1124
+
1125
+ # # GigaSpeech
1126
+ # for disfluency in gigaspeech_disfluencies:
1127
+ # input_str = input_str.replace(disfluency, "")
1128
+ # # convert spelled out punctuation to symbolic form
1129
+ # for punctuation, replacement in gigaspeech_punctuation.items():
1130
+ # input_str = input_str.replace(punctuation, replacement)
1131
+ # if dataset_name == "speechcolab/gigaspeech" and len(input_str):
1132
+ # # for GS, we'll normalize the text to always finish with punctuation
1133
+ # if input_str[-1] not in [".", "?", "!"]:
1134
+ # input_str = input_str + "."
1135
+
1136
+ # # SWB
1137
+ # for disfluency in swb_disfluencies:
1138
+ # input_str = input_str.replace(disfluency, "")
1139
+ # # remove parenthesised text (test data only)
1140
+ # input_str = re.sub("[\(].*?[\)]", "", input_str)
1141
+ # for punctuation in swb_punctuations:
1142
+ # input_str = input_str.replace(punctuation, "")
1143
+ # # replace anomalous words with their correct transcriptions
1144
+ # split_str = input_str.split("/")
1145
+ # if len(split_str) > 1:
1146
+ # input_str = " ".join(
1147
+ # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1148
+
1149
+ # # Earnings 22
1150
+ # for disfluency in earnings_disfluencies:
1151
+ # input_str = input_str.replace(disfluency, "")
1152
+ # # replace mal-formatted ellipsis
1153
+ # input_str = input_str.replace("…", ".")
1154
+
1155
+ # JIWER compliance
1156
+ # remove multiple spaces
1157
+ input_str = re.sub(r"\s\s+", " ", input_str)
1158
+ # strip trailing spaces
1159
+ input_str = input_str.strip()
1160
+
1161
+ # Finally, we tokenize the processed text
1162
+ batch["labels"] = tokenizer(input_str).input_ids
1163
+ batch["labels_length"] = len(batch["labels"])
1164
+ return batch
1165
+
1166
+ vectorized_datasets = raw_datasets.map(
1167
+ prepare_dataset,
1168
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1169
+ num_proc=num_workers,
1170
+ desc="preprocess dataset",
1171
+ )
1172
+
1173
+ # filter data with inputs shorter than min_input_length or longer than max_input_length
1174
+ def is_audio_in_length_range(length):
1175
+ return length > min_input_length and length < max_input_length
1176
+
1177
+ vectorized_datasets = vectorized_datasets.filter(
1178
+ is_audio_in_length_range,
1179
+ num_proc=num_workers,
1180
+ input_columns=["input_length"],
1181
+ )
1182
+
1183
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1184
+ def is_labels_in_length_range(length):
1185
+ return length > min_target_length # and length < max_target_length
1186
+
1187
+ vectorized_datasets = vectorized_datasets.filter(
1188
+ is_labels_in_length_range,
1189
+ num_proc=num_workers,
1190
+ input_columns=["labels_length"],
1191
+ )
1192
+
1193
+ # for large datasets it is advised to run the preprocessing on a
1194
+ # single machine first with `args.preprocessing_only` since there will mostly likely
1195
+ # be a timeout when running the script in distributed mode.
1196
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1197
+ # cached dataset
1198
+ if data_args.preprocessing_only:
1199
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1200
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1201
+ return
1202
+
1203
+ # 8. Load Metrics
1204
+ wer_metric = load_metric("wer")
1205
+ cer_metric = load_metric("cer")
1206
+
1207
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1208
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1209
+
1210
+ pred_str = tokenizer.batch_decode(pred_ids)
1211
+ # we do not want to group tokens when computing the metrics
1212
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
1213
+
1214
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
1215
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
1216
+
1217
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1218
+
1219
+ # 9. save feature extractor, tokenizer and config
1220
+ feature_extractor.save_pretrained(training_args.output_dir)
1221
+ tokenizer.save_pretrained(training_args.output_dir)
1222
+ config.save_pretrained(training_args.output_dir)
1223
+
1224
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1225
+
1226
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1227
+ processor=processor,
1228
+ input_padding="longest",
1229
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1230
+ max_label_length=data_args.max_label_length,
1231
+ )
1232
+
1233
+ # Enable tensorboard only on the master node
1234
+ has_tensorboard = is_tensorboard_available()
1235
+ if has_tensorboard and jax.process_index() == 0:
1236
+ try:
1237
+ from flax.metrics.tensorboard import SummaryWriter
1238
+
1239
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1240
+ except ImportError as ie:
1241
+ has_tensorboard = False
1242
+ logger.warning(
1243
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1244
+ )
1245
+ else:
1246
+ logger.warning(
1247
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1248
+ "Please run `pip install tensorboard` to enable."
1249
+ )
1250
+
1251
+ # 10. Handle the repository creation
1252
+ if training_args.push_to_hub:
1253
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1254
+ git_lfs_extensions = f.read()
1255
+ if "*.wandb" not in git_lfs_extensions:
1256
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1257
+ if training_args.hub_model_id is None:
1258
+ repo_name = get_full_repo_name(
1259
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1260
+ )
1261
+ else:
1262
+ repo_name = training_args.hub_model_id
1263
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1264
+
1265
+ # 11. Initialize our training
1266
+ rng = jax.random.PRNGKey(training_args.seed)
1267
+ rng, dropout_rng = jax.random.split(rng)
1268
+
1269
+ # Store some constants
1270
+ max_steps = int(training_args.max_steps)
1271
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1272
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1273
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1274
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1275
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1276
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1277
+
1278
+ if training_args.do_train:
1279
+ num_train_samples = len(vectorized_datasets[data_args.train_split_name])
1280
+ steps_per_epoch = num_train_samples // batch_size_per_update
1281
+ if max_steps > 0:
1282
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1283
+ total_train_steps = max_steps
1284
+ else:
1285
+ num_epochs = int(training_args.num_train_epochs)
1286
+ total_train_steps = steps_per_epoch * num_epochs
1287
+
1288
+ # Create learning rate schedule
1289
+ # Create learning rate schedule
1290
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1291
+ total_train_steps,
1292
+ training_args.warmup_steps,
1293
+ training_args.learning_rate,
1294
+ )
1295
+
1296
+ # We use Optax's "masking" functionality to not apply weight decay
1297
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1298
+ # mask boolean with the same structure as the parameters.
1299
+ # The mask is True for parameters that should be decayed.
1300
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1301
+ # For FlaxT5, one should correct the layer norm parameter naming
1302
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1303
+ def decay_mask_fn(params):
1304
+ flat_params = traverse_util.flatten_dict(params)
1305
+ layer_norm_params = [
1306
+ (name, "scale")
1307
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1308
+ ]
1309
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1310
+ return traverse_util.unflatten_dict(flat_mask)
1311
+
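+ # Illustrative only: under decay_mask_fn, a kernel leaf such as
+ # ("encoder", "layers", "0", "attention", "q_proj", "kernel") maps to True (weight-decayed),
+ # while any "bias" leaf or a (..., "layer_norm", "scale") leaf maps to False (excluded).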
1312
+ if training_args.adafactor:
1313
+ # Create Adafactor optimizer
1314
+ optim = optax.adafactor(
1315
+ learning_rate=linear_decay_lr_schedule_fn,
1316
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1317
+ weight_decay_rate=training_args.weight_decay,
1318
+ weight_decay_mask=decay_mask_fn,
1319
+ )
1320
+ else:
1321
+ # Create AdamW optimizer
1322
+ optim = optax.adamw(
1323
+ learning_rate=linear_decay_lr_schedule_fn,
1324
+ b1=training_args.adam_beta1,
1325
+ b2=training_args.adam_beta2,
1326
+ eps=training_args.adam_epsilon,
1327
+ weight_decay=training_args.weight_decay,
1328
+ mask=decay_mask_fn,
1329
+ )
1330
+
1331
+ # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
1332
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1333
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
1334
+ else:
1335
+ num_epochs = 0
1336
+ total_train_steps = 0
1337
+ num_train_samples = 0
1338
+ optim = None
1339
+
1340
+ # Setup train state
1341
+ state = MixedPrecisionTrainState.create(
1342
+ step=data_args.skip_steps,
1343
+ apply_fn=model.__call__,
1344
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1345
+ params=model.params,
1346
+ tx=optim,
1347
+ to_dtype=to_dtype,
1348
+ dropout_rng=dropout_rng,
1349
+ max_grad_norm=training_args.max_grad_norm,
1350
+ )
1351
+
1352
+ # Replicate the train state on each device
1353
+ state = state.replicate()
1354
+ blank_id = model.config.pad_token_id
1355
+
1356
+ # Define gradient update step fn
1357
+ def train_step(state, batch):
1358
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1359
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1360
+
1361
+ def compute_loss(params, minibatch):
1362
+ labels = minibatch.pop("labels")
1363
+ logits = state.apply_fn(
1364
+ **minibatch,
1365
+ params=params,
1366
+ dropout_rng=dropout_rng,
1367
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1368
+ train=True,
1369
+ )[0]
1370
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], minibatch["attention_mask"])
1371
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1372
+
1373
+ return loss
1374
+
1375
+ grad_fn = jax.value_and_grad(compute_loss)
+
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
+ loss, grad = grad_fn(to_dtype(state.params), batch)
+
+ # Custom gradient accumulation
+ else:
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
+ batch = jax.tree_util.tree_map(
+ lambda x: x.reshape(
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
+ ),
+ batch,
+ )
+
+ def accum_minibatch_step(accum_grad, minibatch):
+ # compute loss, num labels and grad over minibatch and accumulate
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
+ return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss
+
+ # create an initial state for accumulating losses, num labels and gradients
+ init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params))
+ # loop accum minibatch step over the number of gradient accumulation steps
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
+
+ # update state
+ new_state = state.apply_gradients(
+ grads=grad,
+ dropout_rng=new_dropout_rng,
+ to_dtype=to_dtype,
+ )
+
+ # compute gradient norms over all layers and globally for detailed monitoring
+ layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad)
+ logs = {
+ "layer_grad_norm": layer_grad_norm,
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
+ }
+
+ # compute parameter norms over all layers and globally for detailed monitoring
+ layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params)
+ logs["layer_param_norm"] = layer_param_norm
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
+
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+ metrics.update(logs)
+
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+ # metrics = to_fp32(metrics)
+
+ return new_state, metrics
+
+ # Define eval fn
+ def eval_step(params, batch):
+ labels = batch.pop("labels")
+ logits = model(**batch, params=params, train=False)[0]
+
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
+
+ pred_ids = jnp.argmax(logits, axis=-1)
+
+ # summarize metrics
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+ # metrics = to_fp32(metrics)
+ return metrics, pred_ids
+
+ # Create parallel version of the train and eval step
+ if training_args.do_train:
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+
+ if training_args.do_eval:
+ p_eval_step = jax.pmap(eval_step, "batch")
+
+ def run_evaluation(step):
+ if training_args.do_eval:
+ # ======================== Evaluating ==============================
+ eval_metrics = []
+ eval_preds = []
+ eval_labels = []
+
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size)
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
+
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
+ samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx]
+ batch = data_collator(samples)
+ labels = batch["labels"]
+
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
+ eval_metrics.append(metrics)
+
+ eval_labels.extend(labels)
+
+ # normalize eval metrics
+ eval_metrics = get_metrics(eval_metrics)
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
+ eval_metrics = to_fp32(eval_metrics)
+
+ # always run compute metrics
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
+ eval_metrics.update(error_rate_metric)
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
+
+ # Print metrics and update progress bar
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
+ epochs.write(desc)
+ epochs.desc = desc
+
+ # Save metrics
+ write_wandb_log(eval_metrics, step, prefix="eval")
+ write_wandb_pred(pred_str, label_str, step)
+ # if has_tensorboard and jax.process_index() == 0:
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
+
+ def save_checkpoint(step):
+ # save and push checkpoint to the hub
+ if jax.process_index() == 0:
+ params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
+ model.save_pretrained(training_args.output_dir, params=params)
+ tokenizer.save_pretrained(training_args.output_dir)
+ if training_args.push_to_hub:
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
+
+ skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update)
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {num_train_samples}")
+ logger.info(f" Num Epochs = {num_epochs}")
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
+ logger.info(f" Total optimization steps = {total_train_steps}")
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
+ logger.info(f" Use scan: {config.use_scan}")
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
+ logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)")
+
+ train_time = cur_step = 0
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+ for epoch in epochs:
+ if training_args.do_train:
+ # ======================== Training ================================
+ train_start = time.time()
+
+ if epoch < skip_epochs:
+ logger.info(f"Skipping epoch {epoch + 1}")
+ # Create sampling rng
+ rng, input_rng = jax.random.split(rng)
+ continue
+
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
+ train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
+
+ if data_args.skip_steps > cur_step:
+ logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...")
+ # Gather the indices for creating the batch and do a training step
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
+ if cur_step <= data_args.skip_steps:
+ continue
+
+ samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx]
+ batch = data_collator(samples)
+ batch = shard(batch.data)
+ try:
+ state, train_metric = p_train_step(state, batch)
+ except TypeError as e:
+ logger.warning(f"Encountered the following error:\n{e}")
+
+
+ if cur_step % training_args.logging_steps == 0:
+ # Save metrics
+ train_metric = unreplicate(train_metric)
+ train_time += time.time() - train_start
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name)
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
+ # if has_tensorboard and jax.process_index() == 0:
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+ epochs.write(
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
+ )
+ p_train_step.clear_cache()
+
+ if cur_step % total_train_steps == 0:
+ break
+
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
+ run_evaluation(cur_step)
+
+ if cur_step % training_args.save_steps == 0:
+ save_checkpoint(cur_step)
+
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
+ # run evaluation at the end of the epoch if eval steps are not specified
+ run_evaluation(cur_step)
+ save_checkpoint(cur_step)
+
+ if training_args.do_train:
+ save_checkpoint(cur_step)
+
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
+
+ if training_args.do_eval:
+ run_evaluation(cur_step)
+
+ # TODO: collapse 'do_predict' into the run_evaluation function
+ if training_args.do_predict:
+ for split in [data_args.test_split_name]:
+ # ======================== Evaluating ==============================
+ eval_metrics = []
+ eval_preds = []
+ eval_labels = []
+
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
+
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
+ batch = data_collator(samples)
+ labels = batch["labels"]
+
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
+ eval_metrics.append(metrics)
+
+ eval_labels.extend(labels)
+
+ # normalize eval metrics
+ eval_metrics = get_metrics(eval_metrics)
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
+ eval_metrics = to_fp32(eval_metrics)
+
+ # always run compute metrics
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
+ eval_metrics.update(error_rate_metric)
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
+
+ # Print metrics and update progress bar
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
+ epochs.write(desc)
+ epochs.desc = desc
+
+ # Save metrics
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
+ write_wandb_pred(pred_str, label_str, cur_step, prefix=split)
+ # if has_tensorboard and jax.process_index() == 0:
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
+
+
+ if __name__ == "__main__":
+ main()
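The custom accumulation branch in the train_step above sums per-minibatch gradients with jax.lax.scan before a single optimizer update. Below is a standalone sketch of that pattern, not part of the committed script: the model is a toy linear regression, optax.adamw stands in for the script's optimizer, and accum_steps, micro_bs and feature_dim are made-up sizes. Unlike the script, the sketch also divides the summed gradients by the number of accumulation steps so the update matches a single large-batch step.

import jax
import jax.numpy as jnp
import optax

accum_steps = 4      # illustrative gradient_accumulation_steps
micro_bs = 8         # illustrative per-device microbatch size
feature_dim = 16

params = {"w": jnp.zeros((feature_dim,)), "b": jnp.zeros(())}
tx = optax.adamw(1e-3)
opt_state = tx.init(params)

def loss_fn(params, batch):
    preds = batch["x"] @ params["w"] + params["b"]
    return jnp.mean((preds - batch["y"]) ** 2)

grad_fn = jax.value_and_grad(loss_fn)

@jax.jit
def accumulated_update(params, opt_state, batch):
    # add a leading dimension over accumulation steps, as in the script
    batch = jax.tree_util.tree_map(
        lambda x: x.reshape(accum_steps, micro_bs, *x.shape[1:]), batch
    )

    def accum_minibatch_step(accum_grad, minibatch):
        # compute loss and grad over one minibatch and add it to the running sum
        loss, grad = grad_fn(params, minibatch)
        return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss

    init_grad = jax.tree_util.tree_map(jnp.zeros_like, params)
    grad, losses = jax.lax.scan(accum_minibatch_step, init_grad, batch)

    # average the summed gradients, then apply one optimizer update
    grad = jax.tree_util.tree_map(lambda g: g / accum_steps, grad)
    updates, opt_state = tx.update(grad, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, jnp.mean(losses)

# toy data: one global batch of accum_steps * micro_bs examples
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (accum_steps * micro_bs, feature_dim))
batch = {"x": x, "y": x.sum(axis=-1)}
params, opt_state, loss = accumulated_update(params, opt_state, batch)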
wandb/run-20220810_151736-2jo5la5b/files/config.yaml ADDED
@@ -0,0 +1,33 @@
+ wandb_version: 1
+
+ _wandb:
+ desc: null
+ value:
+ cli_version: 0.12.9
+ code_path: code/run_flax_speech_recognition_ctc.py
+ framework: huggingface
+ huggingface_version: 4.21.0
+ is_jupyter_run: false
+ is_kaggle_kernel: false
+ python_version: 3.8.10
+ start_time: 1660144656
+ t:
+ 1:
+ - 1
+ - 2
+ - 3
+ - 11
+ - 12
+ 2:
+ - 1
+ - 2
+ - 3
+ - 11
+ - 12
+ 3:
+ - 13
+ 4: 3.8.10
+ 5: 0.12.9
+ 6: 4.21.0
+ 8:
+ - 5
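The start_time recorded in this config is a Unix timestamp. A small illustrative check, assuming the host clock is UTC, shows that it reproduces the timestamp in the run directory name:

from datetime import datetime, timezone

start_time = 1660144656  # value recorded in the config above
run_stamp = datetime.fromtimestamp(start_time, tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
print(run_stamp)  # 20220810_151736, matching the run-20220810_151736-2jo5la5b directory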
wandb/run-20220810_151736-2jo5la5b/files/diff.patch ADDED
@@ -0,0 +1,144 @@
+ diff --git a/run.recover.sh b/run.recover.sh
+ index 77ad3fd..632a336 100755
+ --- a/run.recover.sh
+ +++ b/run.recover.sh
+ @@ -11,9 +11,8 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
+ --per_device_train_batch_size="2" \
+ --per_device_eval_batch_size="2" \
+ --gradient_accumulation_steps="1" \
+ - --precision="full_mixed" \
+ + --precision="half_mixed" \
+ --matmul_precision="bfloat16" \
+ - --multisteps \
+ --learning_rate="6.394633237505332e-05" \
+ --skip_steps="275000" \
+ --warmup_steps="2000" \
+ diff --git a/run_flax_speech_recognition_ctc.py b/run_flax_speech_recognition_ctc.py
+ index a330879..4d0f5fc 100644
+ --- a/run_flax_speech_recognition_ctc.py
+ +++ b/run_flax_speech_recognition_ctc.py
+ @@ -415,12 +415,12 @@ class MixedPrecisionTrainState(struct.PyTreeNode):
+ )
+
+ @classmethod
+ - def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
+ + def create(cls, *, apply_fn, params, tx, to_dtype, step=0, **kwargs):
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
+ # downcast optimizer state to bf16 if mixed-precision training
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
+ return cls(
+ - step=0,
+ + step=step,
+ apply_fn=apply_fn,
+ params=params,
+ tx=tx,
+ @@ -1339,6 +1339,7 @@ def main():
+
+ # Setup train state
+ state = MixedPrecisionTrainState.create(
+ + step=data_args.skip_steps,
+ apply_fn=model.__call__,
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
+ params=model.params,
+ @@ -1520,11 +1521,10 @@ def main():
+
+ if epoch < skip_epochs:
+ logger.info(f"Skipping epoch {epoch + 1}")
+ + # Create sampling rng
+ + rng, input_rng = jax.random.split(rng)
+ continue
+
+ - # Create sampling rng
+ - rng, input_rng = jax.random.split(rng)
+ -
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
+ train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
+ @@ -1559,6 +1559,7 @@ def main():
+ epochs.write(
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
+ )
+ + p_train_step.clear_cache()
+
+ if cur_step % total_train_steps == 0:
+ break
+ diff --git a/special_tokens_map.json b/special_tokens_map.json
+ index 218961f..3c0d148 100644
+ --- a/special_tokens_map.json
+ +++ b/special_tokens_map.json
+ @@ -399,6 +399,48 @@
+ "rstrip": false,
+ "single_word": false
+ },
+ + {
+ + "content": "</s>",
+ + "lstrip": false,
+ + "normalized": true,
+ + "rstrip": false,
+ + "single_word": false
+ + },
+ + {
+ + "content": "<s>",
+ + "lstrip": false,
+ + "normalized": true,
+ + "rstrip": false,
+ + "single_word": false
+ + },
+ + {
+ + "content": "</s>",
+ + "lstrip": false,
+ + "normalized": true,
+ + "rstrip": false,
+ + "single_word": false
+ + },
+ + {
+ + "content": "<s>",
+ + "lstrip": false,
+ + "normalized": true,
+ + "rstrip": false,
+ + "single_word": false
+ + },
+ + {
+ + "content": "</s>",
+ + "lstrip": false,
+ + "normalized": true,
+ + "rstrip": false,
+ + "single_word": false
+ + },
+ + {
+ + "content": "<s>",
+ + "lstrip": false,
+ + "normalized": true,
+ + "rstrip": false,
+ + "single_word": false
+ + },
+ {
+ "content": "</s>",
+ "lstrip": false,
+ diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log
+ index 23926ef..90a074d 120000
+ --- a/wandb/debug-internal.log
+ +++ b/wandb/debug-internal.log
+ @@ -1 +1 @@
+ -run-20220805_230151-2y71vcu4/logs/debug-internal.log
+
+ +run-20220810_151736-2jo5la5b/logs/debug-internal.log
+
+ diff --git a/wandb/debug.log b/wandb/debug.log
+ index 279853d..de899a6 120000
+ --- a/wandb/debug.log
+ +++ b/wandb/debug.log
+ @@ -1 +1 @@
+ -run-20220805_230151-2y71vcu4/logs/debug.log
+
+ +run-20220810_151736-2jo5la5b/logs/debug.log
+
+ diff --git a/wandb/latest-run b/wandb/latest-run
+ index f069a7a..0dfb7e0 120000
+ --- a/wandb/latest-run
+ +++ b/wandb/latest-run
+ @@ -1 +1 @@
+ -run-20220805_230151-2y71vcu4
+
+ +run-20220810_151736-2jo5la5b
+
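The patch above recreates MixedPrecisionTrainState with step=data_args.skip_steps, so the resumed run does not replay the learning-rate warmup from step 0. A minimal sketch of why the step counter matters, assuming a warmup plus linear-decay schedule of the same general shape as the script's; total_train_steps is a placeholder, while the warmup, skip and peak learning-rate values come from run.recover.sh in this patch.

import optax

total_train_steps = 300_000      # placeholder: the real total is not shown in this commit
warmup_steps = 2_000             # --warmup_steps from run.recover.sh
skip_steps = 275_000             # --skip_steps from run.recover.sh
peak_lr = 6.394633237505332e-05  # --learning_rate from run.recover.sh

lr_schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, peak_lr, warmup_steps),
        optax.linear_schedule(peak_lr, 0.0, total_train_steps - warmup_steps),
    ],
    boundaries=[warmup_steps],
)

print(lr_schedule(0))           # what a state created with step=0 would see: warmup from zero
print(lr_schedule(skip_steps))  # the much smaller LR a state created with step=skip_steps sees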
wandb/run-20220810_151736-2jo5la5b/files/output.log ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:89c3976bd22db27a53b27fd3d63f60a6563116c92d804aceb8ada0bf7909833f
+ size 224905
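This output.log entry is stored as a Git LFS pointer rather than as the log text itself. A tiny illustrative parser for the three-field pointer format shown above:

pointer = (
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:89c3976bd22db27a53b27fd3d63f60a6563116c92d804aceb8ada0bf7909833f\n"
    "size 224905\n"
)
fields = dict(line.split(" ", 1) for line in pointer.strip().splitlines())
algorithm, digest = fields["oid"].split(":", 1)
size_bytes = int(fields["size"])
print(algorithm, digest[:12], size_bytes)  # sha256 89c3976bd22d 224905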
wandb/run-20220810_151736-2jo5la5b/files/requirements.txt ADDED
@@ -0,0 +1,158 @@
+ absl-py==1.0.0
+ aiohttp==3.8.1
+ aiosignal==1.2.0
+ appdirs==1.4.4
+ astunparse==1.6.3
+ async-timeout==4.0.2
+ attrs==21.4.0
+ audioread==2.1.9
+ backcall==0.2.0
+ cachetools==4.2.4
+ certifi==2021.10.8
+ cffi==1.15.1
+ charset-normalizer==2.0.10
+ chex==0.1.3
+ click==8.0.3
+ cloud-tpu-client==0.10
+ cloud-tpu-profiler==2.4.0
+ clu==0.0.6
+ colorama==0.4.5
+ commonmark==0.9.1
+ configparser==5.2.0
+ contextlib2==21.6.0
+ cycler==0.11.0
+ datasets==2.4.0
+ decorator==5.1.0
+ dill==0.3.4
+ dm-tree==0.1.6
+ docker-pycreds==0.4.0
+ etils==0.6.0
+ exceptiongroup==1.0.0rc8
+ filelock==3.4.2
+ flatbuffers==2.0
+ flax==0.5.3
+ fonttools==4.28.5
+ frozenlist==1.2.0
+ fsspec==2021.11.1
+ future==0.18.2
+ gast==0.4.0
+ gitdb==4.0.9
+ gitpython==3.1.26
+ google-api-core==1.31.5
+ google-api-python-client==1.8.0
+ google-auth-httplib2==0.1.0
+ google-auth-oauthlib==0.4.6
+ google-auth==2.3.3
+ google-pasta==0.2.0
+ googleapis-common-protos==1.54.0
+ grpcio==1.43.0
+ h5py==3.6.0
+ httplib2==0.20.2
+ huggingface-hub==0.2.1
+ hypothesis==6.53.0
+ idna==3.3
+ importlib-metadata==4.10.0
+ importlib-resources==5.4.0
+ ipython==7.31.0
+ jax==0.3.15
+ jaxlib==0.3.15
+ jedi==0.18.1
+ jiwer==2.3.0
+ joblib==1.1.0
+ keras-preprocessing==1.1.2
+ keras==2.7.0
+ kiwisolver==1.3.2
+ libclang==12.0.0
+ librosa==0.9.2
+ libtpu-nightly==0.1.dev20220722
+ llvmlite==0.39.0
+ markdown==3.3.6
+ matplotlib-inline==0.1.3
+ matplotlib==3.5.1
+ ml-collections==0.1.0
+ msgpack==1.0.3
+ multidict==5.2.0
+ multiprocess==0.70.12.2
+ numba==0.56.0
+ numpy==1.22.0
+ oauth2client==4.1.3
+ oauthlib==3.1.1
+ opt-einsum==3.3.0
+ optax==0.1.3
+ packaging==21.3
+ pandas==1.3.5
+ parso==0.8.3
+ pathtools==0.1.2
+ pexpect==4.8.0
+ pickleshare==0.7.5
+ pillow==9.0.0
+ pip==22.2.2
+ pkg-resources==0.0.0
+ pooch==1.6.0
+ promise==2.3
+ prompt-toolkit==3.0.24
+ protobuf==3.19.1
+ psutil==5.9.0
+ ptyprocess==0.7.0
+ pyarrow==6.0.1
+ pyasn1-modules==0.2.8
+ pyasn1==0.4.8
+ pycparser==2.21
+ pyctcdecode==0.4.0
+ pygments==2.11.1
+ pygtrie==2.5.0
+ pyparsing==3.0.6
+ python-dateutil==2.8.2
+ python-levenshtein==0.12.2
+ pytz==2021.3
+ pyyaml==6.0
+ regex==2021.11.10
+ requests-oauthlib==1.3.0
+ requests==2.27.0
+ resampy==0.3.1
+ responses==0.18.0
+ rich==11.2.0
+ rsa==4.8
+ sacremoses==0.0.46
+ scikit-learn==1.1.1
+ scipy==1.7.3
+ sentry-sdk==1.5.2
+ setuptools==44.0.0
+ shortuuid==1.0.8
+ six==1.16.0
+ smmap==5.0.0
+ sortedcontainers==2.4.0
+ soundfile==0.10.3.post1
+ sox==1.4.1
+ subprocess32==3.5.4
+ tensorboard-data-server==0.6.1
+ tensorboard-plugin-wit==1.8.0
+ tensorboard==2.7.0
+ tensorflow-cpu==2.7.0
+ tensorflow-datasets==4.4.0
+ tensorflow-estimator==2.7.0
+ tensorflow-io-gcs-filesystem==0.23.1
+ tensorflow-metadata==1.5.0
+ tensorflow==2.7.0
+ tensorstore==0.1.21
+ termcolor==1.1.0
+ threadpoolctl==3.1.0
+ tokenizers==0.11.2
+ toolz==0.11.2
+ torch==1.12.0
+ torchaudio==0.12.0+cpu
+ tqdm==4.62.3
+ traitlets==5.1.1
+ transformers==4.21.0
+ typing-extensions==4.3.0
+ uritemplate==3.0.1
+ urllib3==1.26.7
+ wandb==0.12.9
+ wcwidth==0.2.5
+ werkzeug==2.0.2
+ wheel==0.37.1
+ wrapt==1.13.3
+ xxhash==2.0.2
+ yarl==1.7.2
+ yaspin==2.1.0
+ zipp==3.7.0
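The list above is the frozen environment wandb captured for this run. An illustrative way to compare a local environment against those pins (not part of the repository); it assumes each requirement name matches its installed distribution name, which is not guaranteed for every entry (e.g. pkg-resources):

from importlib.metadata import PackageNotFoundError, version

with open("requirements.txt") as handle:
    for line in handle:
        name, _, pinned = line.strip().partition("==")
        if not name:
            continue
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = "not installed"
        if installed != pinned:
            print(f"{name}: pinned {pinned}, found {installed}")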