versae commited on
Commit
4b97d7b
1 Parent(s): e2b1320

2y71vcu4: saving weights and logs of step 5k

Browse files
Files changed (50) hide show
  1. events.out.tfevents.1659700313.t1v-n-eedfb410-w-0.3572327.0.v2 +3 -0
  2. events.out.tfevents.1659704423.t1v-n-eedfb410-w-0.2632364.0.v2 +3 -0
  3. events.out.tfevents.1659713857.t1v-n-eedfb410-w-0.1781451.0.v2 +3 -0
  4. events.out.tfevents.1659741255.t1v-n-eedfb410-w-0.825506.0.v2 +3 -0
  5. flax_model.msgpack +1 -1
  6. run.sh +5 -5
  7. special_tokens_map.json +56 -0
  8. wandb/debug-internal.log +1 -1
  9. wandb/debug.log +1 -1
  10. wandb/latest-run +1 -1
  11. wandb/run-20220805_114003-3f1x0hvu/files/code/run_flax_speech_recognition_ctc.py +1631 -0
  12. wandb/run-20220805_114003-3f1x0hvu/files/config.yaml +33 -0
  13. wandb/run-20220805_114003-3f1x0hvu/files/diff.patch +53 -0
  14. wandb/run-20220805_114003-3f1x0hvu/files/output.log +3 -0
  15. wandb/run-20220805_114003-3f1x0hvu/files/requirements.txt +158 -0
  16. wandb/run-20220805_114003-3f1x0hvu/files/wandb-metadata.json +69 -0
  17. wandb/run-20220805_114003-3f1x0hvu/files/wandb-summary.json +1 -0
  18. wandb/run-20220805_114003-3f1x0hvu/logs/debug-internal.log +3 -0
  19. wandb/run-20220805_114003-3f1x0hvu/logs/debug.log +3 -0
  20. wandb/run-20220805_114003-3f1x0hvu/run-3f1x0hvu.wandb +3 -0
  21. wandb/run-20220805_124834-3ep5xqhh/files/code/run_flax_speech_recognition_ctc.py +1631 -0
  22. wandb/run-20220805_124834-3ep5xqhh/files/config.yaml +33 -0
  23. wandb/run-20220805_124834-3ep5xqhh/files/diff.patch +105 -0
  24. wandb/run-20220805_124834-3ep5xqhh/files/output.log +3 -0
  25. wandb/run-20220805_124834-3ep5xqhh/files/requirements.txt +158 -0
  26. wandb/run-20220805_124834-3ep5xqhh/files/wandb-metadata.json +69 -0
  27. wandb/run-20220805_124834-3ep5xqhh/files/wandb-summary.json +1 -0
  28. wandb/run-20220805_124834-3ep5xqhh/logs/debug-internal.log +3 -0
  29. wandb/run-20220805_124834-3ep5xqhh/logs/debug.log +3 -0
  30. wandb/run-20220805_124834-3ep5xqhh/run-3ep5xqhh.wandb +3 -0
  31. wandb/run-20220805_152536-2fzkf8n5/files/code/run_flax_speech_recognition_ctc.py +1631 -0
  32. wandb/run-20220805_152536-2fzkf8n5/files/config.yaml +33 -0
  33. wandb/run-20220805_152536-2fzkf8n5/files/diff.patch +119 -0
  34. wandb/run-20220805_152536-2fzkf8n5/files/output.log +3 -0
  35. wandb/run-20220805_152536-2fzkf8n5/files/requirements.txt +158 -0
  36. wandb/run-20220805_152536-2fzkf8n5/files/wandb-metadata.json +69 -0
  37. wandb/run-20220805_152536-2fzkf8n5/files/wandb-summary.json +1 -0
  38. wandb/run-20220805_152536-2fzkf8n5/logs/debug-internal.log +3 -0
  39. wandb/run-20220805_152536-2fzkf8n5/logs/debug.log +3 -0
  40. wandb/run-20220805_152536-2fzkf8n5/run-2fzkf8n5.wandb +3 -0
  41. wandb/run-20220805_230151-2y71vcu4/files/code/run_flax_speech_recognition_ctc.py +1631 -0
  42. wandb/run-20220805_230151-2y71vcu4/files/config.yaml +27 -0
  43. wandb/run-20220805_230151-2y71vcu4/files/diff.patch +131 -0
  44. wandb/run-20220805_230151-2y71vcu4/files/output.log +3 -0
  45. wandb/run-20220805_230151-2y71vcu4/files/requirements.txt +158 -0
  46. wandb/run-20220805_230151-2y71vcu4/files/wandb-metadata.json +69 -0
  47. wandb/run-20220805_230151-2y71vcu4/files/wandb-summary.json +1 -0
  48. wandb/run-20220805_230151-2y71vcu4/logs/debug-internal.log +3 -0
  49. wandb/run-20220805_230151-2y71vcu4/logs/debug.log +3 -0
  50. wandb/run-20220805_230151-2y71vcu4/run-2y71vcu4.wandb +3 -0
events.out.tfevents.1659700313.t1v-n-eedfb410-w-0.3572327.0.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c97be7c1da6944777366fdb9b6a215f817d3404702de43e7cb635c65bce877c3
3
+ size 40
events.out.tfevents.1659704423.t1v-n-eedfb410-w-0.2632364.0.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9acb1bfa25db8551780fabc1f10392872609619a94b58bb4d3da1dc556c4f882
3
+ size 40
events.out.tfevents.1659713857.t1v-n-eedfb410-w-0.1781451.0.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7be4e8f9f2175b6c57c706283fc3d06d821d847cce440f5ed5561f91c721a68
3
+ size 40
events.out.tfevents.1659741255.t1v-n-eedfb410-w-0.825506.0.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea9456621b28d25411e36f9a7e20faa6862a27c9b3165c65e01f70413e35ae2f
3
+ size 40
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c0dfd76c4de572f7ef0d3b62ef8e1512bb9c3814cdffd8f9e8ae54845bd9dc9d
3
  size 3850218852
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:635cdc48780adc608e330d5ba666a9ea24d6f5425496bf993f1286ead4a35623
3
  size 3850218852
run.sh CHANGED
@@ -1,6 +1,6 @@
1
  WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_ctc.py \
2
  --model_name_or_path="facebook/wav2vec2-xls-r-1b" \
3
- --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst" \
4
  --tokenizer_name="./" \
5
  --output_dir="./" \
6
  --overwrite_output_dir \
@@ -11,13 +11,13 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
11
  --precision="full_mixed" \
12
  --matmul_precision="bfloat16" \
13
  --multisteps \
14
- --learning_rate="2e-4" \
15
  --warmup_steps="2000" \
16
  --length_column_name="input_length" \
17
  --evaluation_strategy="steps" \
18
  --text_column_name="text" \
19
- --save_steps="4000" \
20
- --eval_steps="4000" \
21
  --logging_steps="100" \
22
  --layerdrop="0.041" \
23
  --attention_dropout="0.094" \
@@ -42,7 +42,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
42
  --ctc_zero_infinity \
43
  --do_lower_case \
44
  --wandb_project="wav2vec2" \
45
- --wandb_name="wav2vec2-1b-npsc-nst" \
46
  --remove_punctuation
47
 
48
 
 
1
  WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_ctc.py \
2
  --model_name_or_path="facebook/wav2vec2-xls-r-1b" \
3
+ --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst-tpu" \
4
  --tokenizer_name="./" \
5
  --output_dir="./" \
6
  --overwrite_output_dir \
 
11
  --precision="full_mixed" \
12
  --matmul_precision="bfloat16" \
13
  --multisteps \
14
+ --learning_rate="1e-4" \
15
  --warmup_steps="2000" \
16
  --length_column_name="input_length" \
17
  --evaluation_strategy="steps" \
18
  --text_column_name="text" \
19
+ --save_steps="5000" \
20
+ --eval_steps="5000" \
21
  --logging_steps="100" \
22
  --layerdrop="0.041" \
23
  --attention_dropout="0.094" \
 
42
  --ctc_zero_infinity \
43
  --do_lower_case \
44
  --wandb_project="wav2vec2" \
45
+ --wandb_name="wav2vec2-1b-npsc-nst-tpu" \
46
  --remove_punctuation
47
 
48
 
special_tokens_map.json CHANGED
@@ -343,6 +343,62 @@
343
  "rstrip": false,
344
  "single_word": false
345
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  {
347
  "content": "</s>",
348
  "lstrip": false,
 
343
  "rstrip": false,
344
  "single_word": false
345
  },
346
+ {
347
+ "content": "</s>",
348
+ "lstrip": false,
349
+ "normalized": true,
350
+ "rstrip": false,
351
+ "single_word": false
352
+ },
353
+ {
354
+ "content": "<s>",
355
+ "lstrip": false,
356
+ "normalized": true,
357
+ "rstrip": false,
358
+ "single_word": false
359
+ },
360
+ {
361
+ "content": "</s>",
362
+ "lstrip": false,
363
+ "normalized": true,
364
+ "rstrip": false,
365
+ "single_word": false
366
+ },
367
+ {
368
+ "content": "<s>",
369
+ "lstrip": false,
370
+ "normalized": true,
371
+ "rstrip": false,
372
+ "single_word": false
373
+ },
374
+ {
375
+ "content": "</s>",
376
+ "lstrip": false,
377
+ "normalized": true,
378
+ "rstrip": false,
379
+ "single_word": false
380
+ },
381
+ {
382
+ "content": "<s>",
383
+ "lstrip": false,
384
+ "normalized": true,
385
+ "rstrip": false,
386
+ "single_word": false
387
+ },
388
+ {
389
+ "content": "</s>",
390
+ "lstrip": false,
391
+ "normalized": true,
392
+ "rstrip": false,
393
+ "single_word": false
394
+ },
395
+ {
396
+ "content": "<s>",
397
+ "lstrip": false,
398
+ "normalized": true,
399
+ "rstrip": false,
400
+ "single_word": false
401
+ },
402
  {
403
  "content": "</s>",
404
  "lstrip": false,
wandb/debug-internal.log CHANGED
@@ -1 +1 @@
1
- run-20220803_091109-yit1e59z/logs/debug-internal.log
 
1
+ run-20220805_230151-2y71vcu4/logs/debug-internal.log
wandb/debug.log CHANGED
@@ -1 +1 @@
1
- run-20220803_091109-yit1e59z/logs/debug.log
 
1
+ run-20220805_230151-2y71vcu4/logs/debug.log
wandb/latest-run CHANGED
@@ -1 +1 @@
1
- run-20220803_091109-yit1e59z
 
1
+ run-20220805_230151-2y71vcu4
wandb/run-20220805_114003-3f1x0hvu/files/code/run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import unreplicate, pad_shard_unpad
44
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.utils import check_min_version
59
+ from transformers.utils.versions import require_version
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ check_min_version("4.17.0.dev0")
64
+
65
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ @flax.struct.dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ feature_extractor_name: Optional[str] = field(
86
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
91
+ )
92
+ use_fast_tokenizer: bool = field(
93
+ default=True,
94
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
95
+ )
96
+ model_revision: str = field(
97
+ default="main",
98
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
99
+ )
100
+ use_auth_token: bool = field(
101
+ default=False,
102
+ metadata={
103
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
104
+ "with private models)."
105
+ },
106
+ )
107
+ freeze_feature_encoder: bool = field(
108
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
109
+ )
110
+ attention_dropout: float = field(
111
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
112
+ )
113
+ activation_dropout: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
117
+ },
118
+ )
119
+ hidden_dropout: float = field(
120
+ default=0.1,
121
+ metadata={
122
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
123
+ },
124
+ )
125
+ feat_proj_dropout: float = field(
126
+ default=0.0,
127
+ metadata={
128
+ "help": "The feat proj dropout probability for feature encoder representations."
129
+ },
130
+ )
131
+ final_dropout: float = field(
132
+ default=0.0,
133
+ metadata={"help": "The dropout probability for the final projection layer."},
134
+ )
135
+ mask_time_prob: float = field(
136
+ default=0.1,
137
+ metadata={
138
+ "help": "The spec aug dropout probability for feature encoder representations."
139
+ },
140
+ )
141
+ mask_time_length: int = field(
142
+ default=10,
143
+ metadata={"help": "Length of vector span to mask along the time axis."},
144
+ )
145
+ mask_feature_prob: float = field(
146
+ default=0.0,
147
+ metadata={
148
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
149
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
150
+ },
151
+ )
152
+ mask_feature_length: int = field(
153
+ default=10,
154
+ metadata={"help": "Length of vector span to mask along the feature axis."},
155
+ )
156
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
157
+ ctc_loss_reduction: Optional[str] = field(
158
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
159
+ )
160
+ ctc_zero_infinity: Optional[bool] = field(
161
+ default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."}
162
+ )
163
+
164
+
165
+ @flax.struct.dataclass
166
+ class DataTrainingArguments:
167
+ """
168
+ Arguments pertaining to what data we are going to input our model for training and eval.
169
+ """
170
+
171
+ dataset_name: str = field(
172
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
173
+ )
174
+ dataset_config_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
176
+ )
177
+ text_column: Optional[str] = field(
178
+ default=None,
179
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
180
+ )
181
+ dataset_cache_dir: Optional[str] = field(
182
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
183
+ )
184
+ overwrite_cache: bool = field(
185
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
186
+ )
187
+ preprocessing_num_workers: Optional[int] = field(
188
+ default=None,
189
+ metadata={"help": "The number of processes to use for the preprocessing."},
190
+ )
191
+ max_train_samples: Optional[int] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ },
197
+ )
198
+ max_eval_samples: Optional[int] = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
202
+ "value if set."
203
+ },
204
+ )
205
+ max_test_samples: Optional[int] = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
209
+ "value if set."
210
+ },
211
+ )
212
+ audio_column_name: str = field(
213
+ default="audio",
214
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
215
+ )
216
+ text_column_name: str = field(
217
+ default="text",
218
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
219
+ )
220
+ max_duration_in_seconds: float = field(
221
+ default=20.0,
222
+ metadata={
223
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
224
+ },
225
+ )
226
+ min_duration_in_seconds: float = field(
227
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
228
+ )
229
+ max_label_length: Optional[int] = field(
230
+ default=512,
231
+ metadata={
232
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
233
+ "than this will be filtered."
234
+ },
235
+ )
236
+ min_label_length: Optional[int] = field(
237
+ default=2,
238
+ metadata={
239
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
240
+ "than this will be filtered."
241
+ },
242
+ )
243
+ pad_input_to_multiple_of: Optional[int] = field(
244
+ default=32000,
245
+ metadata={
246
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
247
+ "This is important to avoid triggering recompilations on TPU."
248
+ },
249
+ )
250
+ pad_target_to_multiple_of: Optional[int] = field(
251
+ default=None,
252
+ metadata={
253
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
254
+ "This is important to avoid triggering recompilations on TPU."
255
+ },
256
+ )
257
+ preprocessing_only: bool = field(
258
+ default=False,
259
+ metadata={
260
+ "help": "Whether to only do data preprocessing and skip training. "
261
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
262
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
263
+ "so that the cached datasets can consequently be loaded in distributed training"
264
+ },
265
+ )
266
+ train_split_name: str = field(
267
+ default="train",
268
+ metadata={
269
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
270
+ },
271
+ )
272
+ eval_split_name: str = field(
273
+ default="validation",
274
+ metadata={
275
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
276
+ },
277
+ )
278
+ do_lower_case: bool = field(
279
+ default=True,
280
+ metadata={"help": "Whether the target text should be lower cased."},
281
+ )
282
+ wandb_project: str = field(
283
+ default="flax-speech-recognition-ctc",
284
+ metadata={"help": "The name of the wandb project."},
285
+ )
286
+ wandb_name: str = field(
287
+ default=None,
288
+ metadata={"help": "The name of the wandb run."},
289
+ )
290
+ wandb_job_type: str = field(
291
+ default="CTC",
292
+ metadata={"help": "The name of the wandb job type."},
293
+ )
294
+ test_split_name: str = field(
295
+ default="test",
296
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
297
+ )
298
+ remove_punctuation: bool = field(
299
+ default=False, metadata={"help": "Whether or not to remove punctuation during training."}
300
+ )
301
+ skip_steps: Optional[int] = field(
302
+ default=0,
303
+ metadata={
304
+ "help": "Skip this number of steps. Useful to continue training"
305
+ },
306
+ )
307
+
308
+
309
+ # @flax.struct.dataclass
310
+ @dataclass
311
+ class FlaxTrainingArguments(TrainingArguments):
312
+ precision: str = field(
313
+ default="full",
314
+ metadata={
315
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
316
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
317
+ },
318
+ )
319
+ matmul_precision: str = field(
320
+ default="default",
321
+ metadata={
322
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
323
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
324
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
325
+ "it only changes the behaviors of calls with no such argument provided. "
326
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
327
+ },
328
+ )
329
+ multisteps: bool = field(
330
+ default=False,
331
+ metadata={
332
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
333
+ "a custom gradient accumulation implementation will be employed."
334
+ },
335
+ )
336
+
337
+
338
+ def to_fp32(t):
339
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
340
+
341
+
342
+ def to_bf16(t):
343
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
344
+
345
+
346
+ class MixedPrecisionTrainState(struct.PyTreeNode):
347
+ """Train state for use with a single Optax optimizer.
348
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
349
+
350
+ Synopsis::
351
+
352
+ state = TrainState.create(
353
+ apply_fn=model.apply,
354
+ params=variables['params'],
355
+ tx=tx)
356
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
357
+ for batch in data:
358
+ grads = grad_fn(state.params, batch)
359
+ state = state.apply_gradients(grads=grads)
360
+
361
+ Args:
362
+ step: Counter starts at 0 and is incremented by every call to
363
+ `.apply_gradients()`.
364
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
365
+ convenience to have a shorter params list for the `train_step()` function
366
+ in your training loop.
367
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
368
+ tx: An Optax gradient transformation.
369
+ opt_state: The state for `tx`.
370
+ dropout_rng: PRNG key for stochastic operations.
371
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
372
+ """
373
+
374
+ step: int
375
+ apply_fn: Callable = struct.field(pytree_node=False)
376
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
377
+ params: core.FrozenDict[str, Any]
378
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
379
+ opt_state: optax.OptState
380
+ dropout_rng: jnp.ndarray
381
+ max_grad_norm: Optional[float] = 1.0
382
+
383
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
384
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
385
+
386
+ Note that internally this function calls `.tx.update()` followed by a call
387
+ to `optax.apply_updates()` to update `params` and `opt_state`.
388
+
389
+ Args:
390
+ grads: Gradients that have the same pytree structure as `.params`.
391
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
392
+
393
+ Returns:
394
+ An updated instance of `self` with `step` incremented by one, `params`
395
+ and `opt_state` updated by applying `grads`, and additional attributes
396
+ replaced as specified by `kwargs`.
397
+ """
398
+
399
+ # clip gradients by global l2 norm
400
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
401
+ g_norm = linear_algebra.global_norm(grads)
402
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
403
+ grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
404
+
405
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
406
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
407
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
408
+
409
+ new_params = optax.apply_updates(self.params, updates)
410
+ return self.replace(
411
+ step=self.step + 1,
412
+ params=new_params,
413
+ opt_state=to_dtype(new_opt_state),
414
+ **kwargs,
415
+ )
416
+
417
+ @classmethod
418
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
419
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
420
+ # downcast optimizer state to bf16 if mixed-precision training
421
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
422
+ return cls(
423
+ step=0,
424
+ apply_fn=apply_fn,
425
+ params=params,
426
+ tx=tx,
427
+ opt_state=opt_state,
428
+ **kwargs,
429
+ )
430
+
431
+ def replicate(self):
432
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
433
+
434
+
435
+ @flax.struct.dataclass
436
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
437
+ """
438
+ Data collator that will dynamically pad the inputs received.
439
+ Args:
440
+ processor ([`Wav2Vec2Processor`])
441
+ The processor used for proccessing the data.
442
+ decoder_start_token_id (:obj: `int`)
443
+ The begin-of-sentence of the decoder.
444
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
445
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
446
+ among:
447
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
448
+ sequence if provided).
449
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
450
+ maximum acceptable input length for the model if that argument is not provided.
451
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
452
+ different lengths).
453
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
454
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
455
+ See above for details.
456
+ max_input_length (:obj:`float`, `optional`):
457
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
458
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
459
+ If set will pad the input sequence to a multiple of the provided value.
460
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
461
+ 7.5 (Volta).
462
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
463
+ If set will pad the target sequence to a multiple of the provided value.
464
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
465
+ 7.5 (Volta).
466
+ """
467
+
468
+ processor: Any
469
+ input_padding: Union[bool, str] = "longest"
470
+ label_padding: Union[bool, str] = "max_length"
471
+ pad_input_to_multiple_of: Optional[int] = None
472
+ pad_to_multiple_of_label: Optional[int] = None
473
+ max_input_length: Optional[float] = None
474
+ max_label_length: Optional[float] = None
475
+
476
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
477
+ # split inputs and labels since they have to be of different lengths and need
478
+ # different padding methods
479
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
480
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
481
+
482
+ # reformat list to dict and set to pytorch format
483
+ batch = self.processor.feature_extractor.pad(
484
+ input_features,
485
+ max_length=self.max_input_length,
486
+ padding=self.input_padding,
487
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
488
+ return_tensors="np",
489
+ )
490
+
491
+ labels_batch = self.processor.tokenizer.pad(
492
+ label_features,
493
+ max_length=self.max_label_length,
494
+ padding=self.label_padding,
495
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
496
+ return_tensors="np",
497
+ )
498
+
499
+ labels = labels_batch["input_ids"]
500
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
501
+ labels = labels.filled(fill_value=-100)
502
+
503
+ batch["labels"] = labels
504
+
505
+ return batch
506
+
507
+
508
+ def get_grouped_indices(
509
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
510
+ ) -> np.array:
511
+ """
512
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
513
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
514
+ lengths. To do this, the indices are:
515
+
516
+ - randomly permuted (if a JAX rng is specified)
517
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
518
+ - sorted by length in each mega-batch
519
+
520
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
521
+ maximum length placed first, so that an OOM happens sooner rather than later.
522
+ """
523
+ lengths = dataset["input_length"]
524
+
525
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
526
+ if mega_batch_mult is None:
527
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
528
+ # Just in case, for tiny datasets
529
+ if mega_batch_mult == 0:
530
+ mega_batch_mult = 1
531
+
532
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
533
+ num_samples = len(lengths)
534
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
535
+
536
+ megabatch_size = mega_batch_mult * batch_size
537
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
538
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
539
+
540
+ # The rest is to get the biggest batch first.
541
+ # Since each megabatch is sorted by descending length, the longest element is the first
542
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
543
+ max_idx = np.argmax(megabatch_maximums).item()
544
+ # Switch to put the longest batch in first position
545
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
546
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
547
+
548
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
549
+
550
+ return megabatches
551
+
552
+
553
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
554
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
555
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
556
+ num_samples = len(samples_idx)
557
+ if drop_last:
558
+ samples_to_remove = num_samples % batch_size
559
+ if samples_to_remove != 0:
560
+ samples_idx = samples_idx[:-samples_to_remove]
561
+ sections_split = num_samples // batch_size
562
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
563
+ else:
564
+ sections_split = math.ceil(num_samples / batch_size)
565
+ samples_idx = np.array_split(samples_idx, sections_split)
566
+ return samples_idx
567
+
568
+
569
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
570
+ summary_writer.scalar("train_time", train_time, step)
571
+
572
+ train_metrics = get_metrics(train_metrics)
573
+ for key, vals in train_metrics.items():
574
+ tag = f"train_{key}"
575
+ for i, val in enumerate(vals):
576
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
577
+
578
+
579
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
580
+ for metric_name, value in eval_metrics.items():
581
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
582
+
583
+ if pred_str is not None:
584
+ # write output actual predictions for debugging
585
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
586
+
587
+
588
+ def write_wandb_log(metrics, step, prefix=None):
589
+ if jax.process_index() == 0:
590
+ log_metrics = {}
591
+ for k, v in metrics.items():
592
+ if "layer" in k:
593
+ log_metrics[f"{k}/"] = v
594
+ elif prefix is not None:
595
+ log_metrics[f"{prefix}/{k}"] = v
596
+ else:
597
+ log_metrics[k] = v
598
+ wandb.log(log_metrics, step)
599
+
600
+
601
+ def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"):
602
+ if jax.process_index() == 0:
603
+ # convert str data to a wandb compatible format
604
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
605
+ # we'll log the first 50 predictions for each epoch
606
+ wandb.log(
607
+ {
608
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
609
+ columns=["label_str", "pred_str"], data=str_data[:num_log]
610
+ )
611
+ },
612
+ step,
613
+ )
614
+
615
+
616
+ def create_learning_rate_fn(
617
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
618
+ ) -> Callable[[int], jnp.array]:
619
+ """Returns a linear warmup, linear_decay learning rate function."""
620
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
621
+ decay_fn = optax.linear_schedule(
622
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
623
+ )
624
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
625
+ return schedule_fn
626
+
627
+
628
+ def ctc_loss(
629
+ logits,
630
+ logits_attention_mask,
631
+ labels,
632
+ blank_id,
633
+ loss_reduction="mean",
634
+ output_emission_dict=False,
635
+ log_epsilon=-100000.0,
636
+ ):
637
+ """Computes CTC loss.
638
+ This function performs forward computation over an FSA with `N * 2` states
639
+ where `N` is the max number of labels. The states are split into two groups:
640
+ Phi states and emission states. a phi-state accepts repetition of
641
+ phi (blank)-symbols and transits to emission state when the correct label is
642
+ observed. An emission state accepts repetition of the label and transits to
643
+ the next phi states at any time (so called epsilon-transition).
644
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
645
+ and `N` denotes the time steps in `labels`.
646
+ Args:
647
+ logits: (B, T, K)-array containing log-probabilities of each class.
648
+ logitpaddings: (B, T)-array. Padding indicators for `logits`.
649
+ labels: (B, N)-array containing reference integer labels.
650
+ labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently,
651
+ `labels` must be right-padded, i.e. each row of `labelpaddings` must be
652
+ repetition of zeroes, followed by repetition of ones.
653
+ blank_id: Id for blank token.
654
+ loss_reduction: one of "mean", "sum", "default"
655
+ - "none": no reduction is applied.
656
+ - "mean": output loss will be divided by target lengths and then the
657
+ mean over the batch is taken.
658
+ - "sum": output loss are summed over batch
659
+ output_emission_dict: whether to output additional information about the emission probs
660
+ Returns:
661
+ A pair of `(per_seq_loss, aux)`.
662
+ per_seq_loss:
663
+ (B,)-array containing loss values for each sequence in the batch.
664
+ aux: Dictionary containing interim variables used for computing losses.
665
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
666
+ phi-state corresponding to the n-th label.
667
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
668
+ emission-state corresponding to the n-th label.
669
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
670
+ corresponding to each time frame.
671
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
672
+ corresponding to each time frame.
673
+ """
674
+ # label paddings are indicated by -100
675
+ labelpaddings = labels < 0
676
+ # logit paddings are the inverse of attention_mask
677
+ logitpaddings = ~logits_attention_mask
678
+
679
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
680
+ batchsize, unused_maxinputlen, num_classes = logits.shape
681
+ batchsize_, maxlabellen = labels.shape
682
+
683
+ logprobs = jax.nn.log_softmax(logits)
684
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
685
+
686
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
687
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
688
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
689
+
690
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
691
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
692
+
693
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
694
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
695
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
696
+
697
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
698
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
699
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
700
+
701
+ def loop_body(prev, x):
702
+ prev_phi, prev_emit = prev
703
+ # emit-to-phi epsilon transition, except if the next label is repetition
704
+ prev_phi_orig = prev_phi
705
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
706
+
707
+ logprob_emit, logprob_phi, pad = x
708
+
709
+ # phi-to-emit transition
710
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
711
+ # self-loop transition
712
+ next_phi = prev_phi + logprob_phi
713
+ # emit-to-phi blank transition only when the next label is repetition
714
+ next_phi = next_phi.at[:, 1:].set(
715
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
716
+ )
717
+
718
+ pad = pad.reshape((batchsize, 1))
719
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
720
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
721
+
722
+ return (next_phi, next_emit), (next_phi, next_emit)
723
+
724
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
725
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
726
+
727
+ # last row needs to be updated with the last epsilon transition
728
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
729
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
730
+
731
+ # extract per_seq_loss
732
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
733
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
734
+
735
+ if loss_reduction == "mean":
736
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
737
+ loss = (per_seq_loss / target_lengths).mean()
738
+ elif loss_reduction == "sum":
739
+ loss = per_seq_loss.sum()
740
+ else:
741
+ loss = per_seq_loss
742
+
743
+ if not output_emission_dict:
744
+ return loss
745
+
746
+ return loss, {
747
+ "logalpha_phi": logalpha_phi,
748
+ "logalpha_emit": logalpha_emit,
749
+ "logprobs_phi": logprobs_phi,
750
+ "logprobs_emit": logprobs_emit,
751
+ }
752
+
753
+
754
+ def make_dataset(data_args, seed=42):
755
+ # Pre-processing dataset
756
+ import re
757
+
758
+ def map_nst(entry):
759
+ text = entry["text"].lower()
760
+ text = text.replace("(...vær stille under dette opptaket...)", "")
761
+ text = re.sub('[áàâ]', 'a', text)
762
+ text = re.sub('[ä]', 'æ', text)
763
+ text = re.sub('[éèëê]', 'e', text)
764
+ text = re.sub('[íìïî]', 'i', text)
765
+ text = re.sub('[óòöô]', 'o', text)
766
+ text = re.sub('[ö]', 'ø', text)
767
+ text = re.sub('[ç]', 'c', text)
768
+ text = re.sub('[úùüû]', 'u', text)
769
+ # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
770
+ text = re.sub('\s+', ' ', text)
771
+ return {"text": text}
772
+
773
+ def filter_nst(entry):
774
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
775
+ return False # Too short
776
+ if re.match(entry["type"], "pIW|CA"):
777
+ return False # Spelling out words
778
+ return True
779
+
780
+ def filter_npsc(entry):
781
+ # False if there are digits in the text
782
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
783
+ return False # Too short
784
+ if re.search("\d", entry["text"]):
785
+ return False
786
+ return True
787
+
788
+ def map_npsc(entry):
789
+ batch = {"text": entry["text"].lower()}
790
+ batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
791
+ batch["text"] = re.sub('[ä]', 'æ', batch["text"])
792
+ batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
793
+ batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
794
+ batch["text"] = re.sub('[óòöô]', 'o', batch["text"])
795
+ batch["text"] = re.sub('[ö]', 'ø', batch["text"])
796
+ batch["text"] = re.sub('[ç]', 'c', batch["text"])
797
+ batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
798
+ batch["text"] = re.sub('\s', ' ', batch["text"])
799
+ batch["text"] = re.sub('<ee>', 'eee', batch["text"])
800
+ batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
801
+ batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
802
+ batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
803
+ # batch["text"] = re.sub('<inaudible>', '?', batch["text"])
804
+ if "<" in batch["text"]:
805
+ raise ValueError(batch["text"])
806
+ return batch
807
+
808
+ nst = datasets.load_dataset("NbAiLab/NST", "no-close")
809
+ npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
810
+ # TODO NST_hesitate
811
+
812
+ split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC
813
+ nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed)
814
+ nst[data_args.train_split_name] = nst_train["train"]
815
+ nst[data_args.eval_split_name] = nst_train["test"]
816
+
817
+ nst = nst.filter(filter_nst).map(
818
+ map_nst,
819
+ num_proc=data_args.preprocessing_num_workers,
820
+ desc="filtering NST",
821
+ ).shuffle(seed=seed)
822
+ npsc = npsc.filter(filter_npsc).map(
823
+ map_npsc,
824
+ num_proc=data_args.preprocessing_num_workers,
825
+ desc="filtering NPSC",
826
+ ).shuffle(seed=seed)
827
+
828
+ npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]])
829
+ nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]])
830
+
831
+ combined = {}
832
+ for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name:
833
+ probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples
834
+ probs = (probs / probs.sum()).tolist()
835
+ comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
836
+ combined[split] = comb
837
+
838
+ return datasets.DatasetDict(**combined)
839
+
840
+ def main():
841
+ # 1. Parse input arguments
842
+ # See all possible arguments in src/transformers/training_args.py
843
+ # or by passing the --help flag to this script.
844
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
845
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
846
+
847
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
848
+ # If we pass only one argument to the script and it's the path to a json file,
849
+ # let's parse it to get our arguments.
850
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
851
+ else:
852
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
853
+
854
+ # 2. Setup logging
855
+ # Make one log on every process with the configuration for debugging.
856
+ logging.basicConfig(
857
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
858
+ datefmt="%m/%d/%Y %H:%M:%S",
859
+ handlers=[logging.StreamHandler(sys.stdout)],
860
+ )
861
+ # Set the verbosity to info of the Transformers logger.
862
+ # We only want one process per machine to log things on the screen.
863
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
864
+ if jax.process_index() == 0:
865
+ datasets.utils.logging.set_verbosity_warning()
866
+ transformers.utils.logging.set_verbosity_info()
867
+ else:
868
+ datasets.utils.logging.set_verbosity_error()
869
+ transformers.utils.logging.set_verbosity_error()
870
+
871
+ # Set up wandb run
872
+ if jax.process_index() == 0:
873
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
874
+
875
+ logger.info("Training/evaluation parameters %s", training_args)
876
+
877
+ # Set the default TPU matmul precision and display the number of devices
878
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
879
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
880
+
881
+ # 4. Load dataset
882
+
883
+ set_seed(training_args.seed)
884
+ raw_datasets = make_dataset(data_args, seed=training_args.seed)
885
+
886
+ # raw_datasets = DatasetDict()
887
+
888
+ # if training_args.do_train:
889
+ # raw_datasets[data_args.train_split_name] = load_dataset(
890
+ # data_args.dataset_name,
891
+ # data_args.dataset_config_name,
892
+ # split=data_args.train_split_name,
893
+ # cache_dir=data_args.dataset_cache_dir,
894
+ # use_auth_token=True if model_args.use_auth_token else None,
895
+ # )
896
+
897
+ # if training_args.do_eval:
898
+ # raw_datasets[data_args.eval_split_name] = load_dataset(
899
+ # data_args.dataset_name,
900
+ # data_args.dataset_config_name,
901
+ # split=data_args.eval_split_name,
902
+ # cache_dir=data_args.dataset_cache_dir,
903
+ # use_auth_token=True if model_args.use_auth_token else None,
904
+ # )
905
+
906
+ # if training_args.do_predict:
907
+ # test_split = data_args.test_split_name.split("+")
908
+ # for split in test_split:
909
+ # raw_datasets[split] = load_dataset(
910
+ # data_args.dataset_name,
911
+ # data_args.dataset_config_name,
912
+ # split=split,
913
+ # cache_dir=data_args.dataset_cache_dir,
914
+ # use_auth_token=True if model_args.use_auth_token else None,
915
+ # )
916
+
917
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
918
+ raise ValueError(
919
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
920
+ "training, evaluation or prediction has to be done."
921
+ )
922
+
923
+ # if not training, there is no need to run multiple epochs
924
+ if not training_args.do_train:
925
+ training_args.num_train_epochs = 1
926
+
927
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
928
+ raise ValueError(
929
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
930
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
931
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
932
+ )
933
+
934
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
935
+ raise ValueError(
936
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
937
+ "Make sure to set `--text_column_name` to the correct text column - one of "
938
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
939
+ )
940
+
941
+ # 5. Load pretrained model, tokenizer, and feature extractor
942
+ #
943
+ # Distributed training:
944
+ # The .from_pretrained methods guarantee that only one local process can concurrently
945
+ config = Wav2Vec2Config.from_pretrained(
946
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
947
+ cache_dir=model_args.cache_dir,
948
+ revision=model_args.model_revision,
949
+ use_auth_token=True if model_args.use_auth_token else None,
950
+ )
951
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
952
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
953
+ cache_dir=model_args.cache_dir,
954
+ revision=model_args.model_revision,
955
+ use_auth_token=True if model_args.use_auth_token else None,
956
+ )
957
+ tokenizer = AutoTokenizer.from_pretrained(
958
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
959
+ cache_dir=model_args.cache_dir,
960
+ revision=model_args.model_revision,
961
+ use_auth_token=True if model_args.use_auth_token else None,
962
+ )
963
+ # update config according to training args, model args, and tokenizer attributes
964
+ config.update(
965
+ {
966
+ "feat_proj_dropout": model_args.feat_proj_dropout,
967
+ "attention_dropout": model_args.attention_dropout,
968
+ "hidden_dropout": model_args.hidden_dropout,
969
+ "final_dropout": model_args.final_dropout,
970
+ "mask_time_prob": model_args.mask_time_prob,
971
+ "mask_time_length": model_args.mask_time_length,
972
+ "mask_feature_prob": model_args.mask_feature_prob,
973
+ "mask_feature_length": model_args.mask_feature_length,
974
+ "gradient_checkpointing": training_args.gradient_checkpointing,
975
+ "layerdrop": model_args.layerdrop,
976
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
977
+ "ctc_zero_infinity": model_args.ctc_zero_infinity,
978
+ "pad_token_id": tokenizer.pad_token_id,
979
+ "vocab_size": tokenizer.vocab_size, # len(tokenizer),
980
+ "activation_dropout": model_args.activation_dropout,
981
+ }
982
+ )
983
+
984
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
985
+ raise ValueError(
986
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
987
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus,"
988
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely "
989
+ "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`."
990
+ )
991
+
992
+ if training_args.precision == "full_mixed":
993
+ dtype = jnp.bfloat16
994
+ training_args.mixed_precision = True
995
+ elif training_args.precision == "half_mixed":
996
+ dtype = jnp.bfloat16
997
+ training_args.mixed_precision = False
998
+ else:
999
+ dtype = jnp.float32
1000
+ training_args.mixed_precision = False
1001
+
1002
+ try:
1003
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1004
+ model_args.model_name_or_path,
1005
+ config=config,
1006
+ dtype=dtype,
1007
+ cache_dir=model_args.cache_dir,
1008
+ revision=model_args.model_revision,
1009
+ use_auth_token=True if model_args.use_auth_token else None,
1010
+ )
1011
+ except:
1012
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1013
+ model_args.model_name_or_path,
1014
+ config=config,
1015
+ dtype=dtype,
1016
+ cache_dir=model_args.cache_dir,
1017
+ revision=model_args.model_revision,
1018
+ use_auth_token=True if model_args.use_auth_token else None,
1019
+ from_pt=True,
1020
+ )
1021
+
1022
+ # 6. Resample speech dataset ALWAYS
1023
+ raw_datasets = raw_datasets.cast_column(
1024
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
1025
+ )
1026
+
1027
+ # 7. Preprocessing the datasets.
1028
+ # We need to read the audio files as arrays and tokenize the targets.
1029
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
1030
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
1031
+ max_target_length = data_args.max_label_length
1032
+ min_target_length = data_args.min_label_length
1033
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
1034
+ audio_column_name = data_args.audio_column_name
1035
+ num_workers = data_args.preprocessing_num_workers
1036
+ text_column_name = data_args.text_column_name
1037
+ model_input_name = feature_extractor.model_input_names[0]
1038
+ do_lower_case = data_args.do_lower_case
1039
+ dataset_name = data_args.dataset_name
1040
+ chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
1041
+ chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]'
1042
+ # gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
1043
+ # gigaspeech_disfluencies = ["<other>", "<sil>"]
1044
+ # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
1045
+ # "[vocalized-noise]", "_1"]
1046
+ # swb_punctuations = ["{", "}", "[", "]-", "]"]
1047
+ # earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>", "<unk>"]
1048
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
1049
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
1050
+
1051
+ if training_args.do_train and data_args.max_train_samples is not None:
1052
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples))
1053
+
1054
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1055
+ raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples))
1056
+
1057
+ if training_args.do_predict and data_args.max_test_samples is not None:
1058
+ raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_eval_samples))
1059
+
1060
+ if training_args.do_train and data_args.remove_punctuation:
1061
+
1062
+ def remove_punctuation(batch):
1063
+ batch[text_column_name] = (
1064
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
1065
+ )
1066
+
1067
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map(
1068
+ remove_punctuation,
1069
+ num_proc=data_args.preprocessing_num_workers,
1070
+ desc="removing punctuation from train split",
1071
+ )
1072
+
1073
+ # filter data where the targets are ignored in scoring
1074
+ def is_target_labels(input_str):
1075
+ return input_str.lower() not in ignore_segments
1076
+
1077
+ raw_datasets = raw_datasets.filter(
1078
+ is_target_labels,
1079
+ num_proc=num_workers,
1080
+ input_columns=[text_column_name],
1081
+ desc="filtering data where the targets are ignored in scoring",
1082
+ )
1083
+
1084
+ def prepare_dataset(batch):
1085
+ # process audio
1086
+ try:
1087
+ sample = batch[audio_column_name]
1088
+ except ValueError:
1089
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
1090
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1091
+ # process audio length
1092
+ batch[model_input_name] = inputs.input_values[0]
1093
+ batch["input_length"] = len(batch["input_values"])
1094
+
1095
+ # process targets
1096
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
1097
+
1098
+ # if dataset_name == "google/xtreme_s":
1099
+ # # Finally, we tokenize the processed text
1100
+ # batch["labels"] = tokenizer(input_str).input_ids
1101
+ # batch["labels_length"] = len(batch["labels"])
1102
+ # return batch
1103
+
1104
+ # # Common Voice 9
1105
+ # if input_str.startswith('"') and input_str.endswith('"'):
1106
+ # # we can remove trailing quotation marks as they do not affect the transcription
1107
+ # input_str = input_str[1:-1]
1108
+ # # normalize quotation marks
1109
+ # input_str = re.sub(r'["“”]', '"', input_str)
1110
+ # # normalize apostrophes
1111
+ # input_str = re.sub(r"[’']", "'", input_str)
1112
+ # # normalize hyphens
1113
+ # input_str = re.sub(r"[—–]", "-", input_str)
1114
+ # # replace double quotation marks with single
1115
+ # input_str = input_str.replace('""', '"')
1116
+ # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str):
1117
+ # # for CV9, we'll normalize the text to always finish with punctuation
1118
+ # if input_str[-1] not in [".", "?", "!"]:
1119
+ # input_str = input_str + "."
1120
+
1121
+ # # TEDLIUM-3
1122
+ # # delete the <unk> token from the text and replace spaced apostrophes with un-spaced
1123
+ # input_str = input_str.replace("<unk>", "").replace(" '", "'")
1124
+
1125
+ # # GigaSpeech
1126
+ # for disfluency in gigaspeech_disfluencies:
1127
+ # input_str = input_str.replace(disfluency, "")
1128
+ # # convert spelled out punctuation to symbolic form
1129
+ # for punctuation, replacement in gigaspeech_punctuation.items():
1130
+ # input_str = input_str.replace(punctuation, replacement)
1131
+ # if dataset_name == "speechcolab/gigaspeech" and len(input_str):
1132
+ # # for GS, we'll normalize the text to always finish with punctuation
1133
+ # if input_str[-1] not in [".", "?", "!"]:
1134
+ # input_str = input_str + "."
1135
+
1136
+ # # SWB
1137
+ # for disfluency in swb_disfluencies:
1138
+ # input_str = input_str.replace(disfluency, "")
1139
+ # # remove parenthesised text (test data only)
1140
+ # input_str = re.sub("[\(].*?[\)]", "", input_str)
1141
+ # for punctuation in swb_punctuations:
1142
+ # input_str = input_str.replace(punctuation, "")
1143
+ # # replace anomalous words with their correct transcriptions
1144
+ # split_str = input_str.split("/")
1145
+ # if len(split_str) > 1:
1146
+ # input_str = " ".join(
1147
+ # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1148
+
1149
+ # # Earnings 22
1150
+ # for disfluency in earnings_disfluencies:
1151
+ # input_str = input_str.replace(disfluency, "")
1152
+ # # replace mal-formatted ellipsis
1153
+ # input_str = input_str.replace("…", ".")
1154
+
1155
+ # JIWER compliance
1156
+ # remove multiple spaces
1157
+ input_str = re.sub(r"\s\s+", " ", input_str)
1158
+ # strip trailing spaces
1159
+ input_str = input_str.strip()
1160
+
1161
+ # Finally, we tokenize the processed text
1162
+ batch["labels"] = tokenizer(input_str).input_ids
1163
+ batch["labels_length"] = len(batch["labels"])
1164
+ return batch
1165
+
1166
+ vectorized_datasets = raw_datasets.map(
1167
+ prepare_dataset,
1168
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1169
+ num_proc=num_workers,
1170
+ desc="preprocess dataset",
1171
+ )
1172
+
1173
+ # filter data with inputs shorter than min_input_length or longer than max_input_length
1174
+ def is_audio_in_length_range(length):
1175
+ return length > min_input_length and length < max_input_length
1176
+
1177
+ vectorized_datasets = vectorized_datasets.filter(
1178
+ is_audio_in_length_range,
1179
+ num_proc=num_workers,
1180
+ input_columns=["input_length"],
1181
+ )
1182
+
1183
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1184
+ def is_labels_in_length_range(length):
1185
+ return length > min_target_length # and length < max_target_length
1186
+
1187
+ vectorized_datasets = vectorized_datasets.filter(
1188
+ is_labels_in_length_range,
1189
+ num_proc=num_workers,
1190
+ input_columns=["labels_length"],
1191
+ )
1192
+
1193
+ # for large datasets it is advised to run the preprocessing on a
1194
+ # single machine first with `args.preprocessing_only` since there will mostly likely
1195
+ # be a timeout when running the script in distributed mode.
1196
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1197
+ # cached dataset
1198
+ if data_args.preprocessing_only:
1199
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1200
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1201
+ return
1202
+
1203
+ # 8. Load Metrics
1204
+ wer_metric = load_metric("wer")
1205
+ cer_metric = load_metric("cer")
1206
+
1207
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1208
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1209
+
1210
+ pred_str = tokenizer.batch_decode(pred_ids)
1211
+ # we do not want to group tokens when computing the metrics
1212
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
1213
+
1214
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
1215
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
1216
+
1217
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1218
+
1219
+ # 9. save feature extractor, tokenizer and config
1220
+ feature_extractor.save_pretrained(training_args.output_dir)
1221
+ tokenizer.save_pretrained(training_args.output_dir)
1222
+ config.save_pretrained(training_args.output_dir)
1223
+
1224
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1225
+
1226
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1227
+ processor=processor,
1228
+ input_padding="longest",
1229
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1230
+ max_label_length=data_args.max_label_length,
1231
+ )
1232
+
1233
+ # Enable tensorboard only on the master node
1234
+ has_tensorboard = is_tensorboard_available()
1235
+ if has_tensorboard and jax.process_index() == 0:
1236
+ try:
1237
+ from flax.metrics.tensorboard import SummaryWriter
1238
+
1239
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1240
+ except ImportError as ie:
1241
+ has_tensorboard = False
1242
+ logger.warning(
1243
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1244
+ )
1245
+ else:
1246
+ logger.warning(
1247
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1248
+ "Please run `pip install tensorboard` to enable."
1249
+ )
1250
+
1251
+ # 10. Handle the repository creation
1252
+ if training_args.push_to_hub:
1253
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1254
+ git_lfs_extensions = f.read()
1255
+ if "*.wandb" not in git_lfs_extensions:
1256
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1257
+ if training_args.hub_model_id is None:
1258
+ repo_name = get_full_repo_name(
1259
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1260
+ )
1261
+ else:
1262
+ repo_name = training_args.hub_model_id
1263
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1264
+
1265
+ # 11. Initialize our training
1266
+ rng = jax.random.PRNGKey(training_args.seed)
1267
+ rng, dropout_rng = jax.random.split(rng)
1268
+
1269
+ # Store some constants
1270
+ max_steps = int(training_args.max_steps)
1271
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1272
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1273
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1274
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1275
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1276
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1277
+
1278
+ if training_args.do_train:
1279
+ num_train_samples = len(vectorized_datasets[data_args.train_split_name])
1280
+ steps_per_epoch = num_train_samples // batch_size_per_update
1281
+ if max_steps > 0:
1282
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1283
+ total_train_steps = max_steps
1284
+ else:
1285
+ num_epochs = int(training_args.num_train_epochs)
1286
+ total_train_steps = steps_per_epoch * num_epochs
1287
+
1288
+ # Create learning rate schedule
1289
+ # Create learning rate schedule
1290
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1291
+ total_train_steps,
1292
+ training_args.warmup_steps,
1293
+ training_args.learning_rate,
1294
+ )
1295
+
1296
+ # We use Optax's "masking" functionality to not apply weight decay
1297
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1298
+ # mask boolean with the same structure as the parameters.
1299
+ # The mask is True for parameters that should be decayed.
1300
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1301
+ # For FlaxT5, one should correct the layer norm parameter naming
1302
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1303
+ def decay_mask_fn(params):
1304
+ flat_params = traverse_util.flatten_dict(params)
1305
+ layer_norm_params = [
1306
+ (name, "scale")
1307
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1308
+ ]
1309
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1310
+ return traverse_util.unflatten_dict(flat_mask)
1311
+
1312
+ if training_args.adafactor:
1313
+ # Create Adafactor optimizer
1314
+ optim = optax.adafactor(
1315
+ learning_rate=linear_decay_lr_schedule_fn,
1316
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1317
+ weight_decay_rate=training_args.weight_decay,
1318
+ weight_decay_mask=decay_mask_fn,
1319
+ )
1320
+ else:
1321
+ # Create AdamW optimizer
1322
+ optim = optax.adamw(
1323
+ learning_rate=linear_decay_lr_schedule_fn,
1324
+ b1=training_args.adam_beta1,
1325
+ b2=training_args.adam_beta2,
1326
+ eps=training_args.adam_epsilon,
1327
+ weight_decay=training_args.weight_decay,
1328
+ mask=decay_mask_fn,
1329
+ )
1330
+
1331
+ # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
1332
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1333
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
1334
+ else:
1335
+ num_epochs = 0
1336
+ total_train_steps = 0
1337
+ num_train_samples = 0
1338
+ optim = None
1339
+
1340
+ # Setup train state
1341
+ state = MixedPrecisionTrainState.create(
1342
+ apply_fn=model.__call__,
1343
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1344
+ params=model.params,
1345
+ tx=optim,
1346
+ to_dtype=to_dtype,
1347
+ dropout_rng=dropout_rng,
1348
+ max_grad_norm=training_args.max_grad_norm,
1349
+ )
1350
+
1351
+ # Replicate the train state on each device
1352
+ state = state.replicate()
1353
+ blank_id = model.config.pad_token_id
1354
+
1355
+ # Define gradient update step fn
1356
+ def train_step(state, batch):
1357
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1358
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1359
+
1360
+ def compute_loss(params, minibatch):
1361
+ labels = minibatch.pop("labels")
1362
+ logits = state.apply_fn(
1363
+ **minibatch,
1364
+ params=params,
1365
+ dropout_rng=dropout_rng,
1366
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1367
+ train=True,
1368
+ )[0]
1369
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"])
1370
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1371
+
1372
+ return loss
1373
+
1374
+ grad_fn = jax.value_and_grad(compute_loss)
1375
+
1376
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1377
+ loss, grad = grad_fn(to_dtype(state.params), batch)
1378
+
1379
+ # Custom gradient accumulation
1380
+ else:
1381
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1382
+ batch = jax.tree_util.tree_map(
1383
+ lambda x: x.reshape(
1384
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1385
+ ),
1386
+ batch,
1387
+ )
1388
+
1389
+ def accum_minibatch_step(accum_grad, minibatch):
1390
+ # compute loss, num labels and grad over minibatch and accumulate
1391
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
1392
+ return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss
1393
+
1394
+ # create an initial state for accumulating losses, num labels and gradients
1395
+ init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params))
1396
+ # loop accum minibatch step over the number of gradient accumulation steps
1397
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
1398
+
1399
+ # update state
1400
+ new_state = state.apply_gradients(
1401
+ grads=grad,
1402
+ dropout_rng=new_dropout_rng,
1403
+ to_dtype=to_dtype,
1404
+ )
1405
+
1406
+ # compute gradient norms over all layers and globally for detailed monitoring
1407
+ layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad)
1408
+ logs = {
1409
+ "layer_grad_norm": layer_grad_norm,
1410
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
1411
+ }
1412
+
1413
+ # compute parameter norms over all layers and globally for detailed monitoring
1414
+ layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params)
1415
+ logs["layer_param_norm"] = layer_param_norm
1416
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
1417
+
1418
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1419
+ metrics.update(logs)
1420
+
1421
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1422
+ # metrics = to_fp32(metrics)
1423
+
1424
+ return new_state, metrics
1425
+
1426
+ # Define eval fn
1427
+ def eval_step(params, batch):
1428
+ labels = batch.pop("labels")
1429
+ logits = model(**batch, params=params, train=False)[0]
1430
+
1431
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
1432
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1433
+
1434
+ pred_ids = jnp.argmax(logits, axis=-1)
1435
+
1436
+ # summarize metrics
1437
+ metrics = {"loss": loss}
1438
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1439
+ # metrics = to_fp32(metrics)
1440
+ return metrics, pred_ids
1441
+
1442
+ # Create parallel version of the train and eval step
1443
+ if training_args.do_train:
1444
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1445
+
1446
+ if training_args.do_eval:
1447
+ p_eval_step = jax.pmap(eval_step, "batch")
1448
+
1449
+ def run_evaluation(step):
1450
+ if training_args.do_eval:
1451
+ # ======================== Evaluating ==============================
1452
+ eval_metrics = []
1453
+ eval_preds = []
1454
+ eval_labels = []
1455
+
1456
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1457
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size)
1458
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1459
+
1460
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1461
+ samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx]
1462
+ batch = data_collator(samples)
1463
+ labels = batch["labels"]
1464
+
1465
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1466
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1467
+ eval_metrics.append(metrics)
1468
+
1469
+ eval_labels.extend(labels)
1470
+
1471
+ # normalize eval metrics
1472
+ eval_metrics = get_metrics(eval_metrics)
1473
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1474
+ eval_metrics = to_fp32(eval_metrics)
1475
+
1476
+ # always run compute metrics
1477
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1478
+ eval_metrics.update(error_rate_metric)
1479
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1480
+
1481
+ # Print metrics and update progress bar
1482
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1483
+ epochs.write(desc)
1484
+ epochs.desc = desc
1485
+
1486
+ # Save metrics
1487
+ write_wandb_log(eval_metrics, step, prefix="eval")
1488
+ write_wandb_pred(pred_str, label_str, step)
1489
+ # if has_tensorboard and jax.process_index() == 0:
1490
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1491
+
1492
+ def save_checkpoint(step):
1493
+ # save and push checkpoint to the hub
1494
+ if jax.process_index() == 0:
1495
+ params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
1496
+ model.save_pretrained(training_args.output_dir, params=params)
1497
+ tokenizer.save_pretrained(training_args.output_dir)
1498
+ if training_args.push_to_hub:
1499
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1500
+
1501
+ skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update)
1502
+ logger.info("***** Running training *****")
1503
+ logger.info(f" Num examples = {num_train_samples}")
1504
+ logger.info(f" Num Epochs = {num_epochs}")
1505
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1506
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1507
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1508
+ logger.info(f" Total optimization steps = {total_train_steps}")
1509
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
1510
+ logger.info(f" Use scan: {config.use_scan}")
1511
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
1512
+ logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)")
1513
+
1514
+ train_time = cur_step = 0
1515
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1516
+ for epoch in epochs:
1517
+ if training_args.do_train:
1518
+ # ======================== Training ================================
1519
+ train_start = time.time()
1520
+
1521
+ if epoch < skip_epochs:
1522
+ logger.info(f"Skipping epoch {epoch + 1}")
1523
+ continue
1524
+
1525
+ # Create sampling rng
1526
+ rng, input_rng = jax.random.split(rng)
1527
+
1528
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1529
+ train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
1530
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1531
+
1532
+ if data_args.skip_steps > cur_step:
1533
+ logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...")
1534
+ # Gather the indices for creating the batch and do a training step
1535
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1536
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1537
+ if cur_step <= data_args.skip_steps:
1538
+ continue
1539
+
1540
+ samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx]
1541
+ batch = data_collator(samples)
1542
+ batch = shard(batch.data)
1543
+ try:
1544
+ state, train_metric = p_train_step(state, batch)
1545
+ except TypeError as e:
1546
+ logger.warning("Encountered following error: \n", e)
1547
+
1548
+
1549
+ if cur_step % training_args.logging_steps == 0:
1550
+ # Save metrics
1551
+ train_metric = unreplicate(train_metric)
1552
+ train_time += time.time() - train_start
1553
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1554
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name)
1555
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1556
+ # if has_tensorboard and jax.process_index() == 0:
1557
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1558
+
1559
+ epochs.write(
1560
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1561
+ )
1562
+
1563
+ if cur_step % total_train_steps == 0:
1564
+ break
1565
+
1566
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1567
+ run_evaluation(cur_step)
1568
+
1569
+ if cur_step % training_args.save_steps == 0:
1570
+ save_checkpoint(cur_step)
1571
+
1572
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1573
+ # run evaluation at the end of the epoch if eval steps are not specified
1574
+ run_evaluation(cur_step)
1575
+ save_checkpoint(cur_step)
1576
+
1577
+ if training_args.do_train:
1578
+ save_checkpoint(cur_step)
1579
+
1580
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1581
+
1582
+ if training_args.do_eval:
1583
+ run_evaluation(cur_step)
1584
+
1585
+ # TODO: collapse 'do_predict' into the run_evaluation function
1586
+ if training_args.do_predict:
1587
+ for split in [data_args.test_split_name]:
1588
+ # ======================== Evaluating ==============================
1589
+ eval_metrics = []
1590
+ eval_preds = []
1591
+ eval_labels = []
1592
+
1593
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
1594
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1595
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1596
+
1597
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
1598
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1599
+ batch = data_collator(samples)
1600
+ labels = batch["labels"]
1601
+
1602
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1603
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1604
+ eval_metrics.append(metrics)
1605
+
1606
+ eval_labels.extend(labels)
1607
+
1608
+ # normalize eval metrics
1609
+ eval_metrics = get_metrics(eval_metrics)
1610
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1611
+ eval_metrics = to_fp32(eval_metrics)
1612
+
1613
+ # always run compute metrics
1614
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1615
+ eval_metrics.update(error_rate_metric)
1616
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1617
+
1618
+ # Print metrics and update progress bar
1619
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1620
+ epochs.write(desc)
1621
+ epochs.desc = desc
1622
+
1623
+ # Save metrics
1624
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
1625
+ write_wandb_pred(pred_str, label_str, cur_step, prefix=split)
1626
+ # if has_tensorboard and jax.process_index() == 0:
1627
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
1628
+
1629
+
1630
+ if __name__ == "__main__":
1631
+ main()
wandb/run-20220805_114003-3f1x0hvu/files/config.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.9
7
+ code_path: code/run_flax_speech_recognition_ctc.py
8
+ framework: huggingface
9
+ huggingface_version: 4.21.0
10
+ is_jupyter_run: false
11
+ is_kaggle_kernel: false
12
+ python_version: 3.8.10
13
+ start_time: 1659699604
14
+ t:
15
+ 1:
16
+ - 1
17
+ - 2
18
+ - 3
19
+ - 11
20
+ - 12
21
+ 2:
22
+ - 1
23
+ - 2
24
+ - 3
25
+ - 11
26
+ - 12
27
+ 3:
28
+ - 13
29
+ 4: 3.8.10
30
+ 5: 0.12.9
31
+ 6: 4.21.0
32
+ 8:
33
+ - 5
wandb/run-20220805_114003-3f1x0hvu/files/diff.patch ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/run.sh b/run.sh
2
+ index 9cc498e..604b808 100755
3
+ --- a/run.sh
4
+ +++ b/run.sh
5
+ @@ -5,8 +5,8 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
6
+ --output_dir="./" \
7
+ --overwrite_output_dir \
8
+ --num_train_epochs="40" \
9
+ - --per_device_train_batch_size="2" \
10
+ - --per_device_eval_batch_size="2" \
11
+ + --per_device_train_batch_size="4" \
12
+ + --per_device_eval_batch_size="4" \
13
+ --gradient_accumulation_steps="1" \
14
+ --precision="full_mixed" \
15
+ --matmul_precision="bfloat16" \
16
+ @@ -16,8 +16,8 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
17
+ --length_column_name="input_length" \
18
+ --evaluation_strategy="steps" \
19
+ --text_column_name="text" \
20
+ - --save_steps="4000" \
21
+ - --eval_steps="4000" \
22
+ + --save_steps="5000" \
23
+ + --eval_steps="5000" \
24
+ --logging_steps="100" \
25
+ --layerdrop="0.041" \
26
+ --attention_dropout="0.094" \
27
+ diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log
28
+ index 50a0b69..4934fdf 120000
29
+ --- a/wandb/debug-internal.log
30
+ +++ b/wandb/debug-internal.log
31
+ @@ -1 +1 @@
32
+ -run-20220803_091109-yit1e59z/logs/debug-internal.log
33
+
34
+ +run-20220805_114003-3f1x0hvu/logs/debug-internal.log
35
+
36
+ diff --git a/wandb/debug.log b/wandb/debug.log
37
+ index 746223d..3328f24 120000
38
+ --- a/wandb/debug.log
39
+ +++ b/wandb/debug.log
40
+ @@ -1 +1 @@
41
+ -run-20220803_091109-yit1e59z/logs/debug.log
42
+
43
+ +run-20220805_114003-3f1x0hvu/logs/debug.log
44
+
45
+ diff --git a/wandb/latest-run b/wandb/latest-run
46
+ index be58b40..d627a18 120000
47
+ --- a/wandb/latest-run
48
+ +++ b/wandb/latest-run
49
+ @@ -1 +1 @@
50
+ -run-20220803_091109-yit1e59z
51
+
52
+ +run-20220805_114003-3f1x0hvu
53
+
wandb/run-20220805_114003-3f1x0hvu/files/output.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5174df6a0b92bb2375fae8ef696b8a151681c1be2513b177260019f859192dd2
3
+ size 224929
wandb/run-20220805_114003-3f1x0hvu/files/requirements.txt ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ appdirs==1.4.4
5
+ astunparse==1.6.3
6
+ async-timeout==4.0.2
7
+ attrs==21.4.0
8
+ audioread==2.1.9
9
+ backcall==0.2.0
10
+ cachetools==4.2.4
11
+ certifi==2021.10.8
12
+ cffi==1.15.1
13
+ charset-normalizer==2.0.10
14
+ chex==0.1.3
15
+ click==8.0.3
16
+ cloud-tpu-client==0.10
17
+ cloud-tpu-profiler==2.4.0
18
+ clu==0.0.6
19
+ colorama==0.4.5
20
+ commonmark==0.9.1
21
+ configparser==5.2.0
22
+ contextlib2==21.6.0
23
+ cycler==0.11.0
24
+ datasets==2.4.0
25
+ decorator==5.1.0
26
+ dill==0.3.4
27
+ dm-tree==0.1.6
28
+ docker-pycreds==0.4.0
29
+ etils==0.6.0
30
+ exceptiongroup==1.0.0rc8
31
+ filelock==3.4.2
32
+ flatbuffers==2.0
33
+ flax==0.5.3
34
+ fonttools==4.28.5
35
+ frozenlist==1.2.0
36
+ fsspec==2021.11.1
37
+ future==0.18.2
38
+ gast==0.4.0
39
+ gitdb==4.0.9
40
+ gitpython==3.1.26
41
+ google-api-core==1.31.5
42
+ google-api-python-client==1.8.0
43
+ google-auth-httplib2==0.1.0
44
+ google-auth-oauthlib==0.4.6
45
+ google-auth==2.3.3
46
+ google-pasta==0.2.0
47
+ googleapis-common-protos==1.54.0
48
+ grpcio==1.43.0
49
+ h5py==3.6.0
50
+ httplib2==0.20.2
51
+ huggingface-hub==0.2.1
52
+ hypothesis==6.53.0
53
+ idna==3.3
54
+ importlib-metadata==4.10.0
55
+ importlib-resources==5.4.0
56
+ ipython==7.31.0
57
+ jax==0.3.15
58
+ jaxlib==0.3.15
59
+ jedi==0.18.1
60
+ jiwer==2.3.0
61
+ joblib==1.1.0
62
+ keras-preprocessing==1.1.2
63
+ keras==2.7.0
64
+ kiwisolver==1.3.2
65
+ libclang==12.0.0
66
+ librosa==0.9.2
67
+ libtpu-nightly==0.1.dev20220722
68
+ llvmlite==0.39.0
69
+ markdown==3.3.6
70
+ matplotlib-inline==0.1.3
71
+ matplotlib==3.5.1
72
+ ml-collections==0.1.0
73
+ msgpack==1.0.3
74
+ multidict==5.2.0
75
+ multiprocess==0.70.12.2
76
+ numba==0.56.0
77
+ numpy==1.22.0
78
+ oauth2client==4.1.3
79
+ oauthlib==3.1.1
80
+ opt-einsum==3.3.0
81
+ optax==0.1.3
82
+ packaging==21.3
83
+ pandas==1.3.5
84
+ parso==0.8.3
85
+ pathtools==0.1.2
86
+ pexpect==4.8.0
87
+ pickleshare==0.7.5
88
+ pillow==9.0.0
89
+ pip==22.2.1
90
+ pkg-resources==0.0.0
91
+ pooch==1.6.0
92
+ promise==2.3
93
+ prompt-toolkit==3.0.24
94
+ protobuf==3.19.1
95
+ psutil==5.9.0
96
+ ptyprocess==0.7.0
97
+ pyarrow==6.0.1
98
+ pyasn1-modules==0.2.8
99
+ pyasn1==0.4.8
100
+ pycparser==2.21
101
+ pyctcdecode==0.4.0
102
+ pygments==2.11.1
103
+ pygtrie==2.5.0
104
+ pyparsing==3.0.6
105
+ python-dateutil==2.8.2
106
+ python-levenshtein==0.12.2
107
+ pytz==2021.3
108
+ pyyaml==6.0
109
+ regex==2021.11.10
110
+ requests-oauthlib==1.3.0
111
+ requests==2.27.0
112
+ resampy==0.3.1
113
+ responses==0.18.0
114
+ rich==11.2.0
115
+ rsa==4.8
116
+ sacremoses==0.0.46
117
+ scikit-learn==1.1.1
118
+ scipy==1.7.3
119
+ sentry-sdk==1.5.2
120
+ setuptools==44.0.0
121
+ shortuuid==1.0.8
122
+ six==1.16.0
123
+ smmap==5.0.0
124
+ sortedcontainers==2.4.0
125
+ soundfile==0.10.3.post1
126
+ sox==1.4.1
127
+ subprocess32==3.5.4
128
+ tensorboard-data-server==0.6.1
129
+ tensorboard-plugin-wit==1.8.0
130
+ tensorboard==2.7.0
131
+ tensorflow-cpu==2.7.0
132
+ tensorflow-datasets==4.4.0
133
+ tensorflow-estimator==2.7.0
134
+ tensorflow-io-gcs-filesystem==0.23.1
135
+ tensorflow-metadata==1.5.0
136
+ tensorflow==2.7.0
137
+ tensorstore==0.1.21
138
+ termcolor==1.1.0
139
+ threadpoolctl==3.1.0
140
+ tokenizers==0.11.2
141
+ toolz==0.11.2
142
+ torch==1.12.0
143
+ torchaudio==0.12.0+cpu
144
+ tqdm==4.62.3
145
+ traitlets==5.1.1
146
+ transformers==4.21.0
147
+ typing-extensions==4.3.0
148
+ uritemplate==3.0.1
149
+ urllib3==1.26.7
150
+ wandb==0.12.9
151
+ wcwidth==0.2.5
152
+ werkzeug==2.0.2
153
+ wheel==0.37.1
154
+ wrapt==1.13.3
155
+ xxhash==2.0.2
156
+ yarl==1.7.2
157
+ yaspin==2.1.0
158
+ zipp==3.7.0
wandb/run-20220805_114003-3f1x0hvu/files/wandb-metadata.json ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2022-08-05T11:40:07.382258",
5
+ "startedAt": "2022-08-05T11:40:03.911271",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--model_name_or_path=facebook/wav2vec2-xls-r-1b",
11
+ "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst",
12
+ "--tokenizer_name=./",
13
+ "--output_dir=./",
14
+ "--overwrite_output_dir",
15
+ "--num_train_epochs=40",
16
+ "--per_device_train_batch_size=4",
17
+ "--per_device_eval_batch_size=4",
18
+ "--gradient_accumulation_steps=1",
19
+ "--precision=full_mixed",
20
+ "--matmul_precision=bfloat16",
21
+ "--multisteps",
22
+ "--learning_rate=2e-4",
23
+ "--warmup_steps=2000",
24
+ "--length_column_name=input_length",
25
+ "--evaluation_strategy=steps",
26
+ "--text_column_name=text",
27
+ "--save_steps=5000",
28
+ "--eval_steps=5000",
29
+ "--logging_steps=100",
30
+ "--layerdrop=0.041",
31
+ "--attention_dropout=0.094",
32
+ "--activation_dropout=0.055",
33
+ "--hidden_dropout=0.047",
34
+ "--save_total_limit=5",
35
+ "--freeze_feature_encoder",
36
+ "--feat_proj_dropout=0.04",
37
+ "--mask_time_prob=0.082",
38
+ "--mask_time_length=10",
39
+ "--mask_feature_prob=0.25",
40
+ "--mask_feature_length=64",
41
+ "--gradient_checkpointing",
42
+ "--min_duration_in_seconds=0.5",
43
+ "--max_duration_in_seconds=30.0",
44
+ "--use_auth_token",
45
+ "--seed=42",
46
+ "--group_by_length",
47
+ "--do_train",
48
+ "--do_eval",
49
+ "--push_to_hub",
50
+ "--preprocessing_num_workers=32",
51
+ "--ctc_zero_infinity",
52
+ "--do_lower_case",
53
+ "--wandb_project=wav2vec2",
54
+ "--wandb_name=wav2vec2-1b-npsc-nst",
55
+ "--remove_punctuation"
56
+ ],
57
+ "state": "running",
58
+ "program": "run_flax_speech_recognition_ctc.py",
59
+ "codePath": "run_flax_speech_recognition_ctc.py",
60
+ "git": {
61
+ "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu",
62
+ "commit": "e2b1320cc68c3ce129a1d654965e0d3eb44e0558"
63
+ },
64
+ "email": "versae@gmail.com",
65
+ "root": "/data/wav2vec2-1b-npsc-nst",
66
+ "host": "t1v-n-eedfb410-w-0",
67
+ "username": "javierr",
68
+ "executable": "/data/flax/bin/python"
69
+ }
wandb/run-20220805_114003-3f1x0hvu/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"_wandb": {"runtime": 716}}
wandb/run-20220805_114003-3f1x0hvu/logs/debug-internal.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eeba4400c2433c15bfaf1d69d20d5af658a679b9f4cf75713f00e9883d3620e0
3
+ size 42862
wandb/run-20220805_114003-3f1x0hvu/logs/debug.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:361ad7147270e83596fd398076999e7c2b9ad69eb4d2afc82355d01434e95eda
3
+ size 5845
wandb/run-20220805_114003-3f1x0hvu/run-3f1x0hvu.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e68f244da66ea282bb58485e3253a0680c39e3f27e0d1c9cd1851021d6a64342
3
+ size 240293
wandb/run-20220805_124834-3ep5xqhh/files/code/run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import unreplicate, pad_shard_unpad
44
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.utils import check_min_version
59
+ from transformers.utils.versions import require_version
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ check_min_version("4.17.0.dev0")
64
+
65
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ @flax.struct.dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ feature_extractor_name: Optional[str] = field(
86
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
91
+ )
92
+ use_fast_tokenizer: bool = field(
93
+ default=True,
94
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
95
+ )
96
+ model_revision: str = field(
97
+ default="main",
98
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
99
+ )
100
+ use_auth_token: bool = field(
101
+ default=False,
102
+ metadata={
103
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
104
+ "with private models)."
105
+ },
106
+ )
107
+ freeze_feature_encoder: bool = field(
108
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
109
+ )
110
+ attention_dropout: float = field(
111
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
112
+ )
113
+ activation_dropout: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
117
+ },
118
+ )
119
+ hidden_dropout: float = field(
120
+ default=0.1,
121
+ metadata={
122
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
123
+ },
124
+ )
125
+ feat_proj_dropout: float = field(
126
+ default=0.0,
127
+ metadata={
128
+ "help": "The feat proj dropout probability for feature encoder representations."
129
+ },
130
+ )
131
+ final_dropout: float = field(
132
+ default=0.0,
133
+ metadata={"help": "The dropout probability for the final projection layer."},
134
+ )
135
+ mask_time_prob: float = field(
136
+ default=0.1,
137
+ metadata={
138
+ "help": "The spec aug dropout probability for feature encoder representations."
139
+ },
140
+ )
141
+ mask_time_length: int = field(
142
+ default=10,
143
+ metadata={"help": "Length of vector span to mask along the time axis."},
144
+ )
145
+ mask_feature_prob: float = field(
146
+ default=0.0,
147
+ metadata={
148
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
149
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
150
+ },
151
+ )
152
+ mask_feature_length: int = field(
153
+ default=10,
154
+ metadata={"help": "Length of vector span to mask along the feature axis."},
155
+ )
156
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
157
+ ctc_loss_reduction: Optional[str] = field(
158
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
159
+ )
160
+ ctc_zero_infinity: Optional[bool] = field(
161
+ default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."}
162
+ )
163
+
164
+
165
+ @flax.struct.dataclass
166
+ class DataTrainingArguments:
167
+ """
168
+ Arguments pertaining to what data we are going to input our model for training and eval.
169
+ """
170
+
171
+ dataset_name: str = field(
172
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
173
+ )
174
+ dataset_config_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
176
+ )
177
+ text_column: Optional[str] = field(
178
+ default=None,
179
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
180
+ )
181
+ dataset_cache_dir: Optional[str] = field(
182
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
183
+ )
184
+ overwrite_cache: bool = field(
185
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
186
+ )
187
+ preprocessing_num_workers: Optional[int] = field(
188
+ default=None,
189
+ metadata={"help": "The number of processes to use for the preprocessing."},
190
+ )
191
+ max_train_samples: Optional[int] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ },
197
+ )
198
+ max_eval_samples: Optional[int] = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
202
+ "value if set."
203
+ },
204
+ )
205
+ max_test_samples: Optional[int] = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
209
+ "value if set."
210
+ },
211
+ )
212
+ audio_column_name: str = field(
213
+ default="audio",
214
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
215
+ )
216
+ text_column_name: str = field(
217
+ default="text",
218
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
219
+ )
220
+ max_duration_in_seconds: float = field(
221
+ default=20.0,
222
+ metadata={
223
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
224
+ },
225
+ )
226
+ min_duration_in_seconds: float = field(
227
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
228
+ )
229
+ max_label_length: Optional[int] = field(
230
+ default=512,
231
+ metadata={
232
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
233
+ "than this will be filtered."
234
+ },
235
+ )
236
+ min_label_length: Optional[int] = field(
237
+ default=2,
238
+ metadata={
239
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
240
+ "than this will be filtered."
241
+ },
242
+ )
243
+ pad_input_to_multiple_of: Optional[int] = field(
244
+ default=32000,
245
+ metadata={
246
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
247
+ "This is important to avoid triggering recompilations on TPU."
248
+ },
249
+ )
250
+ pad_target_to_multiple_of: Optional[int] = field(
251
+ default=None,
252
+ metadata={
253
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
254
+ "This is important to avoid triggering recompilations on TPU."
255
+ },
256
+ )
257
+ preprocessing_only: bool = field(
258
+ default=False,
259
+ metadata={
260
+ "help": "Whether to only do data preprocessing and skip training. "
261
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
262
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
263
+ "so that the cached datasets can consequently be loaded in distributed training"
264
+ },
265
+ )
266
+ train_split_name: str = field(
267
+ default="train",
268
+ metadata={
269
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
270
+ },
271
+ )
272
+ eval_split_name: str = field(
273
+ default="validation",
274
+ metadata={
275
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
276
+ },
277
+ )
278
+ do_lower_case: bool = field(
279
+ default=True,
280
+ metadata={"help": "Whether the target text should be lower cased."},
281
+ )
282
+ wandb_project: str = field(
283
+ default="flax-speech-recognition-ctc",
284
+ metadata={"help": "The name of the wandb project."},
285
+ )
286
+ wandb_name: str = field(
287
+ default=None,
288
+ metadata={"help": "The name of the wandb run."},
289
+ )
290
+ wandb_job_type: str = field(
291
+ default="CTC",
292
+ metadata={"help": "The name of the wandb job type."},
293
+ )
294
+ test_split_name: str = field(
295
+ default="test",
296
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
297
+ )
298
+ remove_punctuation: bool = field(
299
+ default=False, metadata={"help": "Whether or not to remove punctuation during training."}
300
+ )
301
+ skip_steps: Optional[int] = field(
302
+ default=0,
303
+ metadata={
304
+ "help": "Skip this number of steps. Useful to continue training"
305
+ },
306
+ )
307
+
308
+
309
+ # @flax.struct.dataclass
310
+ @dataclass
311
+ class FlaxTrainingArguments(TrainingArguments):
312
+ precision: str = field(
313
+ default="full",
314
+ metadata={
315
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
316
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
317
+ },
318
+ )
319
+ matmul_precision: str = field(
320
+ default="default",
321
+ metadata={
322
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
323
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
324
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
325
+ "it only changes the behaviors of calls with no such argument provided. "
326
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
327
+ },
328
+ )
329
+ multisteps: bool = field(
330
+ default=False,
331
+ metadata={
332
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
333
+ "a custom gradient accumulation implementation will be employed."
334
+ },
335
+ )
336
+
337
+
338
+ def to_fp32(t):
339
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
340
+
341
+
342
+ def to_bf16(t):
343
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
344
+
345
+
346
+ class MixedPrecisionTrainState(struct.PyTreeNode):
347
+ """Train state for use with a single Optax optimizer.
348
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
349
+
350
+ Synopsis::
351
+
352
+ state = TrainState.create(
353
+ apply_fn=model.apply,
354
+ params=variables['params'],
355
+ tx=tx)
356
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
357
+ for batch in data:
358
+ grads = grad_fn(state.params, batch)
359
+ state = state.apply_gradients(grads=grads)
360
+
361
+ Args:
362
+ step: Counter starts at 0 and is incremented by every call to
363
+ `.apply_gradients()`.
364
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
365
+ convenience to have a shorter params list for the `train_step()` function
366
+ in your training loop.
367
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
368
+ tx: An Optax gradient transformation.
369
+ opt_state: The state for `tx`.
370
+ dropout_rng: PRNG key for stochastic operations.
371
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
372
+ """
373
+
374
+ step: int
375
+ apply_fn: Callable = struct.field(pytree_node=False)
376
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
377
+ params: core.FrozenDict[str, Any]
378
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
379
+ opt_state: optax.OptState
380
+ dropout_rng: jnp.ndarray
381
+ max_grad_norm: Optional[float] = 1.0
382
+
383
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
384
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
385
+
386
+ Note that internally this function calls `.tx.update()` followed by a call
387
+ to `optax.apply_updates()` to update `params` and `opt_state`.
388
+
389
+ Args:
390
+ grads: Gradients that have the same pytree structure as `.params`.
391
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
392
+
393
+ Returns:
394
+ An updated instance of `self` with `step` incremented by one, `params`
395
+ and `opt_state` updated by applying `grads`, and additional attributes
396
+ replaced as specified by `kwargs`.
397
+ """
398
+
399
+ # clip gradients by global l2 norm
400
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
401
+ g_norm = linear_algebra.global_norm(grads)
402
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
403
+ grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
404
+
405
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
406
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
407
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
408
+
409
+ new_params = optax.apply_updates(self.params, updates)
410
+ return self.replace(
411
+ step=self.step + 1,
412
+ params=new_params,
413
+ opt_state=to_dtype(new_opt_state),
414
+ **kwargs,
415
+ )
416
+
417
+ @classmethod
418
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
419
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
420
+ # downcast optimizer state to bf16 if mixed-precision training
421
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
422
+ return cls(
423
+ step=0,
424
+ apply_fn=apply_fn,
425
+ params=params,
426
+ tx=tx,
427
+ opt_state=opt_state,
428
+ **kwargs,
429
+ )
430
+
431
+ def replicate(self):
432
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
433
+
434
+
435
+ @flax.struct.dataclass
436
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
437
+ """
438
+ Data collator that will dynamically pad the inputs received.
439
+ Args:
440
+ processor ([`Wav2Vec2Processor`])
441
+ The processor used for proccessing the data.
442
+ decoder_start_token_id (:obj: `int`)
443
+ The begin-of-sentence of the decoder.
444
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
445
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
446
+ among:
447
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
448
+ sequence if provided).
449
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
450
+ maximum acceptable input length for the model if that argument is not provided.
451
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
452
+ different lengths).
453
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
454
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
455
+ See above for details.
456
+ max_input_length (:obj:`float`, `optional`):
457
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
458
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
459
+ If set will pad the input sequence to a multiple of the provided value.
460
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
461
+ 7.5 (Volta).
462
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
463
+ If set will pad the target sequence to a multiple of the provided value.
464
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
465
+ 7.5 (Volta).
466
+ """
467
+
468
+ processor: Any
469
+ input_padding: Union[bool, str] = "longest"
470
+ label_padding: Union[bool, str] = "max_length"
471
+ pad_input_to_multiple_of: Optional[int] = None
472
+ pad_to_multiple_of_label: Optional[int] = None
473
+ max_input_length: Optional[float] = None
474
+ max_label_length: Optional[float] = None
475
+
476
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
477
+ # split inputs and labels since they have to be of different lengths and need
478
+ # different padding methods
479
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
480
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
481
+
482
+ # reformat list to dict and set to pytorch format
483
+ batch = self.processor.feature_extractor.pad(
484
+ input_features,
485
+ max_length=self.max_input_length,
486
+ padding=self.input_padding,
487
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
488
+ return_tensors="np",
489
+ )
490
+
491
+ labels_batch = self.processor.tokenizer.pad(
492
+ label_features,
493
+ max_length=self.max_label_length,
494
+ padding=self.label_padding,
495
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
496
+ return_tensors="np",
497
+ )
498
+
499
+ labels = labels_batch["input_ids"]
500
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
501
+ labels = labels.filled(fill_value=-100)
502
+
503
+ batch["labels"] = labels
504
+
505
+ return batch
506
+
507
+
508
+ def get_grouped_indices(
509
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
510
+ ) -> np.array:
511
+ """
512
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
513
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
514
+ lengths. To do this, the indices are:
515
+
516
+ - randomly permuted (if a JAX rng is specified)
517
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
518
+ - sorted by length in each mega-batch
519
+
520
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
521
+ maximum length placed first, so that an OOM happens sooner rather than later.
522
+ """
523
+ lengths = dataset["input_length"]
524
+
525
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
526
+ if mega_batch_mult is None:
527
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
528
+ # Just in case, for tiny datasets
529
+ if mega_batch_mult == 0:
530
+ mega_batch_mult = 1
531
+
532
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
533
+ num_samples = len(lengths)
534
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
535
+
536
+ megabatch_size = mega_batch_mult * batch_size
537
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
538
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
539
+
540
+ # The rest is to get the biggest batch first.
541
+ # Since each megabatch is sorted by descending length, the longest element is the first
542
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
543
+ max_idx = np.argmax(megabatch_maximums).item()
544
+ # Switch to put the longest batch in first position
545
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
546
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
547
+
548
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
549
+
550
+ return megabatches
551
+
552
+
553
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
554
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
555
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
556
+ num_samples = len(samples_idx)
557
+ if drop_last:
558
+ samples_to_remove = num_samples % batch_size
559
+ if samples_to_remove != 0:
560
+ samples_idx = samples_idx[:-samples_to_remove]
561
+ sections_split = num_samples // batch_size
562
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
563
+ else:
564
+ sections_split = math.ceil(num_samples / batch_size)
565
+ samples_idx = np.array_split(samples_idx, sections_split)
566
+ return samples_idx
567
+
568
+
569
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
570
+ summary_writer.scalar("train_time", train_time, step)
571
+
572
+ train_metrics = get_metrics(train_metrics)
573
+ for key, vals in train_metrics.items():
574
+ tag = f"train_{key}"
575
+ for i, val in enumerate(vals):
576
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
577
+
578
+
579
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
580
+ for metric_name, value in eval_metrics.items():
581
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
582
+
583
+ if pred_str is not None:
584
+ # write output actual predictions for debugging
585
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
586
+
587
+
588
+ def write_wandb_log(metrics, step, prefix=None):
589
+ if jax.process_index() == 0:
590
+ log_metrics = {}
591
+ for k, v in metrics.items():
592
+ if "layer" in k:
593
+ log_metrics[f"{k}/"] = v
594
+ elif prefix is not None:
595
+ log_metrics[f"{prefix}/{k}"] = v
596
+ else:
597
+ log_metrics[k] = v
598
+ wandb.log(log_metrics, step)
599
+
600
+
601
+ def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"):
602
+ if jax.process_index() == 0:
603
+ # convert str data to a wandb compatible format
604
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
605
+ # we'll log the first 50 predictions for each epoch
606
+ wandb.log(
607
+ {
608
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
609
+ columns=["label_str", "pred_str"], data=str_data[:num_log]
610
+ )
611
+ },
612
+ step,
613
+ )
614
+
615
+
616
+ def create_learning_rate_fn(
617
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
618
+ ) -> Callable[[int], jnp.array]:
619
+ """Returns a linear warmup, linear_decay learning rate function."""
620
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
621
+ decay_fn = optax.linear_schedule(
622
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
623
+ )
624
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
625
+ return schedule_fn
626
+
627
+
628
+ def ctc_loss(
629
+ logits,
630
+ logits_attention_mask,
631
+ labels,
632
+ blank_id,
633
+ loss_reduction="mean",
634
+ output_emission_dict=False,
635
+ log_epsilon=-100000.0,
636
+ ):
637
+ """Computes CTC loss.
638
+ This function performs forward computation over an FSA with `N * 2` states
639
+ where `N` is the max number of labels. The states are split into two groups:
640
+ Phi states and emission states. a phi-state accepts repetition of
641
+ phi (blank)-symbols and transits to emission state when the correct label is
642
+ observed. An emission state accepts repetition of the label and transits to
643
+ the next phi states at any time (so called epsilon-transition).
644
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
645
+ and `N` denotes the time steps in `labels`.
646
+ Args:
647
+ logits: (B, T, K)-array containing log-probabilities of each class.
648
+ logitpaddings: (B, T)-array. Padding indicators for `logits`.
649
+ labels: (B, N)-array containing reference integer labels.
650
+ labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently,
651
+ `labels` must be right-padded, i.e. each row of `labelpaddings` must be
652
+ repetition of zeroes, followed by repetition of ones.
653
+ blank_id: Id for blank token.
654
+ loss_reduction: one of "mean", "sum", "default"
655
+ - "none": no reduction is applied.
656
+ - "mean": output loss will be divided by target lengths and then the
657
+ mean over the batch is taken.
658
+ - "sum": output loss are summed over batch
659
+ output_emission_dict: whether to output additional information about the emission probs
660
+ Returns:
661
+ A pair of `(per_seq_loss, aux)`.
662
+ per_seq_loss:
663
+ (B,)-array containing loss values for each sequence in the batch.
664
+ aux: Dictionary containing interim variables used for computing losses.
665
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
666
+ phi-state corresponding to the n-th label.
667
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
668
+ emission-state corresponding to the n-th label.
669
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
670
+ corresponding to each time frame.
671
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
672
+ corresponding to each time frame.
673
+ """
674
+ # label paddings are indicated by -100
675
+ labelpaddings = labels < 0
676
+ # logit paddings are the inverse of attention_mask
677
+ logitpaddings = ~logits_attention_mask
678
+
679
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
680
+ batchsize, unused_maxinputlen, num_classes = logits.shape
681
+ batchsize_, maxlabellen = labels.shape
682
+
683
+ logprobs = jax.nn.log_softmax(logits)
684
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
685
+
686
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
687
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
688
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
689
+
690
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
691
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
692
+
693
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
694
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
695
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
696
+
697
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
698
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
699
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
700
+
701
+ def loop_body(prev, x):
702
+ prev_phi, prev_emit = prev
703
+ # emit-to-phi epsilon transition, except if the next label is repetition
704
+ prev_phi_orig = prev_phi
705
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
706
+
707
+ logprob_emit, logprob_phi, pad = x
708
+
709
+ # phi-to-emit transition
710
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
711
+ # self-loop transition
712
+ next_phi = prev_phi + logprob_phi
713
+ # emit-to-phi blank transition only when the next label is repetition
714
+ next_phi = next_phi.at[:, 1:].set(
715
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
716
+ )
717
+
718
+ pad = pad.reshape((batchsize, 1))
719
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
720
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
721
+
722
+ return (next_phi, next_emit), (next_phi, next_emit)
723
+
724
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
725
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
726
+
727
+ # last row needs to be updated with the last epsilon transition
728
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
729
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
730
+
731
+ # extract per_seq_loss
732
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
733
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
734
+
735
+ if loss_reduction == "mean":
736
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
737
+ loss = (per_seq_loss / target_lengths).mean()
738
+ elif loss_reduction == "sum":
739
+ loss = per_seq_loss.sum()
740
+ else:
741
+ loss = per_seq_loss
742
+
743
+ if not output_emission_dict:
744
+ return loss
745
+
746
+ return loss, {
747
+ "logalpha_phi": logalpha_phi,
748
+ "logalpha_emit": logalpha_emit,
749
+ "logprobs_phi": logprobs_phi,
750
+ "logprobs_emit": logprobs_emit,
751
+ }
752
+
753
+
754
+ def make_dataset(data_args, seed=42):
755
+ # Pre-processing dataset
756
+ import re
757
+
758
+ def map_nst(entry):
759
+ text = entry["text"].lower()
760
+ text = text.replace("(...vær stille under dette opptaket...)", "")
761
+ text = re.sub('[áàâ]', 'a', text)
762
+ text = re.sub('[ä]', 'æ', text)
763
+ text = re.sub('[éèëê]', 'e', text)
764
+ text = re.sub('[íìïî]', 'i', text)
765
+ text = re.sub('[óòöô]', 'o', text)
766
+ text = re.sub('[ö]', 'ø', text)
767
+ text = re.sub('[ç]', 'c', text)
768
+ text = re.sub('[úùüû]', 'u', text)
769
+ # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
770
+ text = re.sub('\s+', ' ', text)
771
+ return {"text": text}
772
+
773
+ def filter_nst(entry):
774
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
775
+ return False # Too short
776
+ if re.match(entry["type"], "pIW|CA"):
777
+ return False # Spelling out words
778
+ return True
779
+
780
+ def filter_npsc(entry):
781
+ # False if there are digits in the text
782
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
783
+ return False # Too short
784
+ if re.search("\d", entry["text"]):
785
+ return False
786
+ return True
787
+
788
+ def map_npsc(entry):
789
+ batch = {"text": entry["text"].lower()}
790
+ batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
791
+ batch["text"] = re.sub('[ä]', 'æ', batch["text"])
792
+ batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
793
+ batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
794
+ batch["text"] = re.sub('[óòöô]', 'o', batch["text"])
795
+ batch["text"] = re.sub('[ö]', 'ø', batch["text"])
796
+ batch["text"] = re.sub('[ç]', 'c', batch["text"])
797
+ batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
798
+ batch["text"] = re.sub('\s', ' ', batch["text"])
799
+ batch["text"] = re.sub('<ee>', 'eee', batch["text"])
800
+ batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
801
+ batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
802
+ batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
803
+ # batch["text"] = re.sub('<inaudible>', '?', batch["text"])
804
+ if "<" in batch["text"]:
805
+ raise ValueError(batch["text"])
806
+ return batch
807
+
808
+ nst = datasets.load_dataset("NbAiLab/NST", "no-close")
809
+ npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
810
+ # TODO NST_hesitate
811
+
812
+ split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC
813
+ nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed)
814
+ nst[data_args.train_split_name] = nst_train["train"]
815
+ nst[data_args.eval_split_name] = nst_train["test"]
816
+
817
+ nst = nst.filter(filter_nst).map(
818
+ map_nst,
819
+ num_proc=data_args.preprocessing_num_workers,
820
+ desc="filtering NST",
821
+ ).shuffle(seed=seed)
822
+ npsc = npsc.filter(filter_npsc).map(
823
+ map_npsc,
824
+ num_proc=data_args.preprocessing_num_workers,
825
+ desc="filtering NPSC",
826
+ ).shuffle(seed=seed)
827
+
828
+ npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]])
829
+ nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]])
830
+
831
+ combined = {}
832
+ for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name:
833
+ probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples
834
+ probs = (probs / probs.sum()).tolist()
835
+ comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
836
+ combined[split] = comb
837
+
838
+ return datasets.DatasetDict(**combined)
839
+
840
+ def main():
841
+ # 1. Parse input arguments
842
+ # See all possible arguments in src/transformers/training_args.py
843
+ # or by passing the --help flag to this script.
844
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
845
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
846
+
847
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
848
+ # If we pass only one argument to the script and it's the path to a json file,
849
+ # let's parse it to get our arguments.
850
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
851
+ else:
852
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
853
+
854
+ # 2. Setup logging
855
+ # Make one log on every process with the configuration for debugging.
856
+ logging.basicConfig(
857
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
858
+ datefmt="%m/%d/%Y %H:%M:%S",
859
+ handlers=[logging.StreamHandler(sys.stdout)],
860
+ )
861
+ # Set the verbosity to info of the Transformers logger.
862
+ # We only want one process per machine to log things on the screen.
863
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
864
+ if jax.process_index() == 0:
865
+ datasets.utils.logging.set_verbosity_warning()
866
+ transformers.utils.logging.set_verbosity_info()
867
+ else:
868
+ datasets.utils.logging.set_verbosity_error()
869
+ transformers.utils.logging.set_verbosity_error()
870
+
871
+ # Set up wandb run
872
+ if jax.process_index() == 0:
873
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
874
+
875
+ logger.info("Training/evaluation parameters %s", training_args)
876
+
877
+ # Set the default TPU matmul precision and display the number of devices
878
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
879
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
880
+
881
+ # 4. Load dataset
882
+
883
+ set_seed(training_args.seed)
884
+ raw_datasets = make_dataset(data_args, seed=training_args.seed)
885
+
886
+ # raw_datasets = DatasetDict()
887
+
888
+ # if training_args.do_train:
889
+ # raw_datasets[data_args.train_split_name] = load_dataset(
890
+ # data_args.dataset_name,
891
+ # data_args.dataset_config_name,
892
+ # split=data_args.train_split_name,
893
+ # cache_dir=data_args.dataset_cache_dir,
894
+ # use_auth_token=True if model_args.use_auth_token else None,
895
+ # )
896
+
897
+ # if training_args.do_eval:
898
+ # raw_datasets[data_args.eval_split_name] = load_dataset(
899
+ # data_args.dataset_name,
900
+ # data_args.dataset_config_name,
901
+ # split=data_args.eval_split_name,
902
+ # cache_dir=data_args.dataset_cache_dir,
903
+ # use_auth_token=True if model_args.use_auth_token else None,
904
+ # )
905
+
906
+ # if training_args.do_predict:
907
+ # test_split = data_args.test_split_name.split("+")
908
+ # for split in test_split:
909
+ # raw_datasets[split] = load_dataset(
910
+ # data_args.dataset_name,
911
+ # data_args.dataset_config_name,
912
+ # split=split,
913
+ # cache_dir=data_args.dataset_cache_dir,
914
+ # use_auth_token=True if model_args.use_auth_token else None,
915
+ # )
916
+
917
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
918
+ raise ValueError(
919
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
920
+ "training, evaluation or prediction has to be done."
921
+ )
922
+
923
+ # if not training, there is no need to run multiple epochs
924
+ if not training_args.do_train:
925
+ training_args.num_train_epochs = 1
926
+
927
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
928
+ raise ValueError(
929
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
930
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
931
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
932
+ )
933
+
934
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
935
+ raise ValueError(
936
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
937
+ "Make sure to set `--text_column_name` to the correct text column - one of "
938
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
939
+ )
940
+
941
+ # 5. Load pretrained model, tokenizer, and feature extractor
942
+ #
943
+ # Distributed training:
944
+ # The .from_pretrained methods guarantee that only one local process can concurrently
945
+ config = Wav2Vec2Config.from_pretrained(
946
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
947
+ cache_dir=model_args.cache_dir,
948
+ revision=model_args.model_revision,
949
+ use_auth_token=True if model_args.use_auth_token else None,
950
+ )
951
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
952
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
953
+ cache_dir=model_args.cache_dir,
954
+ revision=model_args.model_revision,
955
+ use_auth_token=True if model_args.use_auth_token else None,
956
+ )
957
+ tokenizer = AutoTokenizer.from_pretrained(
958
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
959
+ cache_dir=model_args.cache_dir,
960
+ revision=model_args.model_revision,
961
+ use_auth_token=True if model_args.use_auth_token else None,
962
+ )
963
+ # update config according to training args, model args, and tokenizer attributes
964
+ config.update(
965
+ {
966
+ "feat_proj_dropout": model_args.feat_proj_dropout,
967
+ "attention_dropout": model_args.attention_dropout,
968
+ "hidden_dropout": model_args.hidden_dropout,
969
+ "final_dropout": model_args.final_dropout,
970
+ "mask_time_prob": model_args.mask_time_prob,
971
+ "mask_time_length": model_args.mask_time_length,
972
+ "mask_feature_prob": model_args.mask_feature_prob,
973
+ "mask_feature_length": model_args.mask_feature_length,
974
+ "gradient_checkpointing": training_args.gradient_checkpointing,
975
+ "layerdrop": model_args.layerdrop,
976
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
977
+ "ctc_zero_infinity": model_args.ctc_zero_infinity,
978
+ "pad_token_id": tokenizer.pad_token_id,
979
+ "vocab_size": tokenizer.vocab_size, # len(tokenizer),
980
+ "activation_dropout": model_args.activation_dropout,
981
+ }
982
+ )
983
+
984
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
985
+ raise ValueError(
986
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
987
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus,"
988
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely "
989
+ "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`."
990
+ )
991
+
992
+ if training_args.precision == "full_mixed":
993
+ dtype = jnp.bfloat16
994
+ training_args.mixed_precision = True
995
+ elif training_args.precision == "half_mixed":
996
+ dtype = jnp.bfloat16
997
+ training_args.mixed_precision = False
998
+ else:
999
+ dtype = jnp.float32
1000
+ training_args.mixed_precision = False
1001
+
1002
+ try:
1003
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1004
+ model_args.model_name_or_path,
1005
+ config=config,
1006
+ dtype=dtype,
1007
+ cache_dir=model_args.cache_dir,
1008
+ revision=model_args.model_revision,
1009
+ use_auth_token=True if model_args.use_auth_token else None,
1010
+ )
1011
+ except:
1012
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1013
+ model_args.model_name_or_path,
1014
+ config=config,
1015
+ dtype=dtype,
1016
+ cache_dir=model_args.cache_dir,
1017
+ revision=model_args.model_revision,
1018
+ use_auth_token=True if model_args.use_auth_token else None,
1019
+ from_pt=True,
1020
+ )
1021
+
1022
+ # 6. Resample speech dataset ALWAYS
1023
+ raw_datasets = raw_datasets.cast_column(
1024
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
1025
+ )
1026
+
1027
+ # 7. Preprocessing the datasets.
1028
+ # We need to read the audio files as arrays and tokenize the targets.
1029
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
1030
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
1031
+ max_target_length = data_args.max_label_length
1032
+ min_target_length = data_args.min_label_length
1033
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
1034
+ audio_column_name = data_args.audio_column_name
1035
+ num_workers = data_args.preprocessing_num_workers
1036
+ text_column_name = data_args.text_column_name
1037
+ model_input_name = feature_extractor.model_input_names[0]
1038
+ do_lower_case = data_args.do_lower_case
1039
+ dataset_name = data_args.dataset_name
1040
+ chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
1041
+ chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]'
1042
+ # gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
1043
+ # gigaspeech_disfluencies = ["<other>", "<sil>"]
1044
+ # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
1045
+ # "[vocalized-noise]", "_1"]
1046
+ # swb_punctuations = ["{", "}", "[", "]-", "]"]
1047
+ # earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>", "<unk>"]
1048
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
1049
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
1050
+
1051
+ if training_args.do_train and data_args.max_train_samples is not None:
1052
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples))
1053
+
1054
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1055
+ raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples))
1056
+
1057
+ if training_args.do_predict and data_args.max_test_samples is not None:
1058
+ raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_eval_samples))
1059
+
1060
+ if training_args.do_train and data_args.remove_punctuation:
1061
+
1062
+ def remove_punctuation(batch):
1063
+ batch[text_column_name] = (
1064
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
1065
+ )
1066
+
1067
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map(
1068
+ remove_punctuation,
1069
+ num_proc=data_args.preprocessing_num_workers,
1070
+ desc="removing punctuation from train split",
1071
+ )
1072
+
1073
+ # filter data where the targets are ignored in scoring
1074
+ def is_target_labels(input_str):
1075
+ return input_str.lower() not in ignore_segments
1076
+
1077
+ raw_datasets = raw_datasets.filter(
1078
+ is_target_labels,
1079
+ num_proc=num_workers,
1080
+ input_columns=[text_column_name],
1081
+ desc="filtering data where the targets are ignored in scoring",
1082
+ )
1083
+
1084
+ def prepare_dataset(batch):
1085
+ # process audio
1086
+ try:
1087
+ sample = batch[audio_column_name]
1088
+ except ValueError:
1089
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
1090
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1091
+ # process audio length
1092
+ batch[model_input_name] = inputs.input_values[0]
1093
+ batch["input_length"] = len(batch["input_values"])
1094
+
1095
+ # process targets
1096
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
1097
+
1098
+ # if dataset_name == "google/xtreme_s":
1099
+ # # Finally, we tokenize the processed text
1100
+ # batch["labels"] = tokenizer(input_str).input_ids
1101
+ # batch["labels_length"] = len(batch["labels"])
1102
+ # return batch
1103
+
1104
+ # # Common Voice 9
1105
+ # if input_str.startswith('"') and input_str.endswith('"'):
1106
+ # # we can remove trailing quotation marks as they do not affect the transcription
1107
+ # input_str = input_str[1:-1]
1108
+ # # normalize quotation marks
1109
+ # input_str = re.sub(r'["“”]', '"', input_str)
1110
+ # # normalize apostrophes
1111
+ # input_str = re.sub(r"[’']", "'", input_str)
1112
+ # # normalize hyphens
1113
+ # input_str = re.sub(r"[—–]", "-", input_str)
1114
+ # # replace double quotation marks with single
1115
+ # input_str = input_str.replace('""', '"')
1116
+ # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str):
1117
+ # # for CV9, we'll normalize the text to always finish with punctuation
1118
+ # if input_str[-1] not in [".", "?", "!"]:
1119
+ # input_str = input_str + "."
1120
+
1121
+ # # TEDLIUM-3
1122
+ # # delete the <unk> token from the text and replace spaced apostrophes with un-spaced
1123
+ # input_str = input_str.replace("<unk>", "").replace(" '", "'")
1124
+
1125
+ # # GigaSpeech
1126
+ # for disfluency in gigaspeech_disfluencies:
1127
+ # input_str = input_str.replace(disfluency, "")
1128
+ # # convert spelled out punctuation to symbolic form
1129
+ # for punctuation, replacement in gigaspeech_punctuation.items():
1130
+ # input_str = input_str.replace(punctuation, replacement)
1131
+ # if dataset_name == "speechcolab/gigaspeech" and len(input_str):
1132
+ # # for GS, we'll normalize the text to always finish with punctuation
1133
+ # if input_str[-1] not in [".", "?", "!"]:
1134
+ # input_str = input_str + "."
1135
+
1136
+ # # SWB
1137
+ # for disfluency in swb_disfluencies:
1138
+ # input_str = input_str.replace(disfluency, "")
1139
+ # # remove parenthesised text (test data only)
1140
+ # input_str = re.sub("[\(].*?[\)]", "", input_str)
1141
+ # for punctuation in swb_punctuations:
1142
+ # input_str = input_str.replace(punctuation, "")
1143
+ # # replace anomalous words with their correct transcriptions
1144
+ # split_str = input_str.split("/")
1145
+ # if len(split_str) > 1:
1146
+ # input_str = " ".join(
1147
+ # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1148
+
1149
+ # # Earnings 22
1150
+ # for disfluency in earnings_disfluencies:
1151
+ # input_str = input_str.replace(disfluency, "")
1152
+ # # replace mal-formatted ellipsis
1153
+ # input_str = input_str.replace("…", ".")
1154
+
1155
+ # JIWER compliance
1156
+ # remove multiple spaces
1157
+ input_str = re.sub(r"\s\s+", " ", input_str)
1158
+ # strip trailing spaces
1159
+ input_str = input_str.strip()
1160
+
1161
+ # Finally, we tokenize the processed text
1162
+ batch["labels"] = tokenizer(input_str).input_ids
1163
+ batch["labels_length"] = len(batch["labels"])
1164
+ return batch
1165
+
1166
+ vectorized_datasets = raw_datasets.map(
1167
+ prepare_dataset,
1168
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1169
+ num_proc=num_workers,
1170
+ desc="preprocess dataset",
1171
+ )
1172
+
1173
+ # filter data with inputs shorter than min_input_length or longer than max_input_length
1174
+ def is_audio_in_length_range(length):
1175
+ return length > min_input_length and length < max_input_length
1176
+
1177
+ vectorized_datasets = vectorized_datasets.filter(
1178
+ is_audio_in_length_range,
1179
+ num_proc=num_workers,
1180
+ input_columns=["input_length"],
1181
+ )
1182
+
1183
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1184
+ def is_labels_in_length_range(length):
1185
+ return length > min_target_length # and length < max_target_length
1186
+
1187
+ vectorized_datasets = vectorized_datasets.filter(
1188
+ is_labels_in_length_range,
1189
+ num_proc=num_workers,
1190
+ input_columns=["labels_length"],
1191
+ )
1192
+
1193
+ # for large datasets it is advised to run the preprocessing on a
1194
+ # single machine first with `args.preprocessing_only` since there will mostly likely
1195
+ # be a timeout when running the script in distributed mode.
1196
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1197
+ # cached dataset
1198
+ if data_args.preprocessing_only:
1199
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1200
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1201
+ return
1202
+
1203
+ # 8. Load Metrics
1204
+ wer_metric = load_metric("wer")
1205
+ cer_metric = load_metric("cer")
1206
+
1207
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1208
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1209
+
1210
+ pred_str = tokenizer.batch_decode(pred_ids)
1211
+ # we do not want to group tokens when computing the metrics
1212
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
1213
+
1214
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
1215
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
1216
+
1217
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1218
+
1219
+ # 9. save feature extractor, tokenizer and config
1220
+ feature_extractor.save_pretrained(training_args.output_dir)
1221
+ tokenizer.save_pretrained(training_args.output_dir)
1222
+ config.save_pretrained(training_args.output_dir)
1223
+
1224
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1225
+
1226
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1227
+ processor=processor,
1228
+ input_padding="longest",
1229
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1230
+ max_label_length=data_args.max_label_length,
1231
+ )
1232
+
1233
+ # Enable tensorboard only on the master node
1234
+ has_tensorboard = is_tensorboard_available()
1235
+ if has_tensorboard and jax.process_index() == 0:
1236
+ try:
1237
+ from flax.metrics.tensorboard import SummaryWriter
1238
+
1239
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1240
+ except ImportError as ie:
1241
+ has_tensorboard = False
1242
+ logger.warning(
1243
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1244
+ )
1245
+ else:
1246
+ logger.warning(
1247
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1248
+ "Please run `pip install tensorboard` to enable."
1249
+ )
1250
+
1251
+ # 10. Handle the repository creation
1252
+ if training_args.push_to_hub:
1253
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1254
+ git_lfs_extensions = f.read()
1255
+ if "*.wandb" not in git_lfs_extensions:
1256
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1257
+ if training_args.hub_model_id is None:
1258
+ repo_name = get_full_repo_name(
1259
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1260
+ )
1261
+ else:
1262
+ repo_name = training_args.hub_model_id
1263
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1264
+
1265
+ # 11. Initialize our training
1266
+ rng = jax.random.PRNGKey(training_args.seed)
1267
+ rng, dropout_rng = jax.random.split(rng)
1268
+
1269
+ # Store some constants
1270
+ max_steps = int(training_args.max_steps)
1271
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1272
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1273
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1274
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1275
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1276
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1277
+
1278
+ if training_args.do_train:
1279
+ num_train_samples = len(vectorized_datasets[data_args.train_split_name])
1280
+ steps_per_epoch = num_train_samples // batch_size_per_update
1281
+ if max_steps > 0:
1282
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1283
+ total_train_steps = max_steps
1284
+ else:
1285
+ num_epochs = int(training_args.num_train_epochs)
1286
+ total_train_steps = steps_per_epoch * num_epochs
1287
+
1288
+ # Create learning rate schedule
1289
+ # Create learning rate schedule
1290
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1291
+ total_train_steps,
1292
+ training_args.warmup_steps,
1293
+ training_args.learning_rate,
1294
+ )
1295
+
1296
+ # We use Optax's "masking" functionality to not apply weight decay
1297
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1298
+ # mask boolean with the same structure as the parameters.
1299
+ # The mask is True for parameters that should be decayed.
1300
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1301
+ # For FlaxT5, one should correct the layer norm parameter naming
1302
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1303
+ def decay_mask_fn(params):
1304
+ flat_params = traverse_util.flatten_dict(params)
1305
+ layer_norm_params = [
1306
+ (name, "scale")
1307
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1308
+ ]
1309
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1310
+ return traverse_util.unflatten_dict(flat_mask)
1311
+
1312
+ if training_args.adafactor:
1313
+ # Create Adafactor optimizer
1314
+ optim = optax.adafactor(
1315
+ learning_rate=linear_decay_lr_schedule_fn,
1316
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1317
+ weight_decay_rate=training_args.weight_decay,
1318
+ weight_decay_mask=decay_mask_fn,
1319
+ )
1320
+ else:
1321
+ # Create AdamW optimizer
1322
+ optim = optax.adamw(
1323
+ learning_rate=linear_decay_lr_schedule_fn,
1324
+ b1=training_args.adam_beta1,
1325
+ b2=training_args.adam_beta2,
1326
+ eps=training_args.adam_epsilon,
1327
+ weight_decay=training_args.weight_decay,
1328
+ mask=decay_mask_fn,
1329
+ )
1330
+
1331
+ # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
1332
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1333
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
1334
+ else:
1335
+ num_epochs = 0
1336
+ total_train_steps = 0
1337
+ num_train_samples = 0
1338
+ optim = None
1339
+
1340
+ # Setup train state
1341
+ state = MixedPrecisionTrainState.create(
1342
+ apply_fn=model.__call__,
1343
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1344
+ params=model.params,
1345
+ tx=optim,
1346
+ to_dtype=to_dtype,
1347
+ dropout_rng=dropout_rng,
1348
+ max_grad_norm=training_args.max_grad_norm,
1349
+ )
1350
+
1351
+ # Replicate the train state on each device
1352
+ state = state.replicate()
1353
+ blank_id = model.config.pad_token_id
1354
+
1355
+ # Define gradient update step fn
1356
+ def train_step(state, batch):
1357
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1358
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1359
+
1360
+ def compute_loss(params, minibatch):
1361
+ labels = minibatch.pop("labels")
1362
+ logits = state.apply_fn(
1363
+ **minibatch,
1364
+ params=params,
1365
+ dropout_rng=dropout_rng,
1366
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1367
+ train=True,
1368
+ )[0]
1369
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"])
1370
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1371
+
1372
+ return loss
1373
+
1374
+ grad_fn = jax.value_and_grad(compute_loss)
1375
+
1376
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1377
+ loss, grad = grad_fn(to_dtype(state.params), batch)
1378
+
1379
+ # Custom gradient accumulation
1380
+ else:
1381
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1382
+ batch = jax.tree_util.tree_map(
1383
+ lambda x: x.reshape(
1384
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1385
+ ),
1386
+ batch,
1387
+ )
1388
+
1389
+ def accum_minibatch_step(accum_grad, minibatch):
1390
+ # compute loss, num labels and grad over minibatch and accumulate
1391
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
1392
+ return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss
1393
+
1394
+ # create an initial state for accumulating losses, num labels and gradients
1395
+ init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params))
1396
+ # loop accum minibatch step over the number of gradient accumulation steps
1397
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
1398
+
1399
+ # update state
1400
+ new_state = state.apply_gradients(
1401
+ grads=grad,
1402
+ dropout_rng=new_dropout_rng,
1403
+ to_dtype=to_dtype,
1404
+ )
1405
+
1406
+ # compute gradient norms over all layers and globally for detailed monitoring
1407
+ layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad)
1408
+ logs = {
1409
+ "layer_grad_norm": layer_grad_norm,
1410
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
1411
+ }
1412
+
1413
+ # compute parameter norms over all layers and globally for detailed monitoring
1414
+ layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params)
1415
+ logs["layer_param_norm"] = layer_param_norm
1416
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
1417
+
1418
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1419
+ metrics.update(logs)
1420
+
1421
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1422
+ # metrics = to_fp32(metrics)
1423
+
1424
+ return new_state, metrics
1425
+
1426
+ # Define eval fn
1427
+ def eval_step(params, batch):
1428
+ labels = batch.pop("labels")
1429
+ logits = model(**batch, params=params, train=False)[0]
1430
+
1431
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
1432
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1433
+
1434
+ pred_ids = jnp.argmax(logits, axis=-1)
1435
+
1436
+ # summarize metrics
1437
+ metrics = {"loss": loss}
1438
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1439
+ # metrics = to_fp32(metrics)
1440
+ return metrics, pred_ids
1441
+
1442
+ # Create parallel version of the train and eval step
1443
+ if training_args.do_train:
1444
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1445
+
1446
+ if training_args.do_eval:
1447
+ p_eval_step = jax.pmap(eval_step, "batch")
1448
+
1449
+ def run_evaluation(step):
1450
+ if training_args.do_eval:
1451
+ # ======================== Evaluating ==============================
1452
+ eval_metrics = []
1453
+ eval_preds = []
1454
+ eval_labels = []
1455
+
1456
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1457
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size)
1458
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1459
+
1460
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1461
+ samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx]
1462
+ batch = data_collator(samples)
1463
+ labels = batch["labels"]
1464
+
1465
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1466
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1467
+ eval_metrics.append(metrics)
1468
+
1469
+ eval_labels.extend(labels)
1470
+
1471
+ # normalize eval metrics
1472
+ eval_metrics = get_metrics(eval_metrics)
1473
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1474
+ eval_metrics = to_fp32(eval_metrics)
1475
+
1476
+ # always run compute metrics
1477
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1478
+ eval_metrics.update(error_rate_metric)
1479
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1480
+
1481
+ # Print metrics and update progress bar
1482
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1483
+ epochs.write(desc)
1484
+ epochs.desc = desc
1485
+
1486
+ # Save metrics
1487
+ write_wandb_log(eval_metrics, step, prefix="eval")
1488
+ write_wandb_pred(pred_str, label_str, step)
1489
+ # if has_tensorboard and jax.process_index() == 0:
1490
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1491
+
1492
+ def save_checkpoint(step):
1493
+ # save and push checkpoint to the hub
1494
+ if jax.process_index() == 0:
1495
+ params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
1496
+ model.save_pretrained(training_args.output_dir, params=params)
1497
+ tokenizer.save_pretrained(training_args.output_dir)
1498
+ if training_args.push_to_hub:
1499
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1500
+
1501
+ skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update)
1502
+ logger.info("***** Running training *****")
1503
+ logger.info(f" Num examples = {num_train_samples}")
1504
+ logger.info(f" Num Epochs = {num_epochs}")
1505
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1506
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1507
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1508
+ logger.info(f" Total optimization steps = {total_train_steps}")
1509
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
1510
+ logger.info(f" Use scan: {config.use_scan}")
1511
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
1512
+ logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)")
1513
+
1514
+ train_time = cur_step = 0
1515
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1516
+ for epoch in epochs:
1517
+ if training_args.do_train:
1518
+ # ======================== Training ================================
1519
+ train_start = time.time()
1520
+
1521
+ if epoch < skip_epochs:
1522
+ logger.info(f"Skipping epoch {epoch + 1}")
1523
+ continue
1524
+
1525
+ # Create sampling rng
1526
+ rng, input_rng = jax.random.split(rng)
1527
+
1528
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1529
+ train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
1530
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1531
+
1532
+ if data_args.skip_steps > cur_step:
1533
+ logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...")
1534
+ # Gather the indices for creating the batch and do a training step
1535
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1536
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1537
+ if cur_step <= data_args.skip_steps:
1538
+ continue
1539
+
1540
+ samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx]
1541
+ batch = data_collator(samples)
1542
+ batch = shard(batch.data)
1543
+ try:
1544
+ state, train_metric = p_train_step(state, batch)
1545
+ except TypeError as e:
1546
+ logger.warning("Encountered following error: \n", e)
1547
+
1548
+
1549
+ if cur_step % training_args.logging_steps == 0:
1550
+ # Save metrics
1551
+ train_metric = unreplicate(train_metric)
1552
+ train_time += time.time() - train_start
1553
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1554
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name)
1555
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1556
+ # if has_tensorboard and jax.process_index() == 0:
1557
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1558
+
1559
+ epochs.write(
1560
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1561
+ )
1562
+
1563
+ if cur_step % total_train_steps == 0:
1564
+ break
1565
+
1566
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1567
+ run_evaluation(cur_step)
1568
+
1569
+ if cur_step % training_args.save_steps == 0:
1570
+ save_checkpoint(cur_step)
1571
+
1572
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1573
+ # run evaluation at the end of the epoch if eval steps are not specified
1574
+ run_evaluation(cur_step)
1575
+ save_checkpoint(cur_step)
1576
+
1577
+ if training_args.do_train:
1578
+ save_checkpoint(cur_step)
1579
+
1580
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1581
+
1582
+ if training_args.do_eval:
1583
+ run_evaluation(cur_step)
1584
+
1585
+ # TODO: collapse 'do_predict' into the run_evaluation function
1586
+ if training_args.do_predict:
1587
+ for split in [data_args.test_split_name]:
1588
+ # ======================== Evaluating ==============================
1589
+ eval_metrics = []
1590
+ eval_preds = []
1591
+ eval_labels = []
1592
+
1593
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
1594
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1595
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1596
+
1597
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
1598
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1599
+ batch = data_collator(samples)
1600
+ labels = batch["labels"]
1601
+
1602
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1603
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1604
+ eval_metrics.append(metrics)
1605
+
1606
+ eval_labels.extend(labels)
1607
+
1608
+ # normalize eval metrics
1609
+ eval_metrics = get_metrics(eval_metrics)
1610
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1611
+ eval_metrics = to_fp32(eval_metrics)
1612
+
1613
+ # always run compute metrics
1614
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1615
+ eval_metrics.update(error_rate_metric)
1616
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1617
+
1618
+ # Print metrics and update progress bar
1619
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1620
+ epochs.write(desc)
1621
+ epochs.desc = desc
1622
+
1623
+ # Save metrics
1624
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
1625
+ write_wandb_pred(pred_str, label_str, cur_step, prefix=split)
1626
+ # if has_tensorboard and jax.process_index() == 0:
1627
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
1628
+
1629
+
1630
+ if __name__ == "__main__":
1631
+ main()
wandb/run-20220805_124834-3ep5xqhh/files/config.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.9
7
+ code_path: code/run_flax_speech_recognition_ctc.py
8
+ framework: huggingface
9
+ huggingface_version: 4.21.0
10
+ is_jupyter_run: false
11
+ is_kaggle_kernel: false
12
+ python_version: 3.8.10
13
+ start_time: 1659703714
14
+ t:
15
+ 1:
16
+ - 1
17
+ - 2
18
+ - 3
19
+ - 11
20
+ - 12
21
+ 2:
22
+ - 1
23
+ - 2
24
+ - 3
25
+ - 11
26
+ - 12
27
+ 3:
28
+ - 13
29
+ 4: 3.8.10
30
+ 5: 0.12.9
31
+ 6: 4.21.0
32
+ 8:
33
+ - 5
wandb/run-20220805_124834-3ep5xqhh/files/diff.patch ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/config.json b/config.json
2
+ index 260219f..246b797 100644
3
+ --- a/config.json
4
+ +++ b/config.json
5
+ @@ -5,7 +5,7 @@
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ - "Wav2Vec2ForCTC"
10
+ + "Wav2Vec2ForPreTraining"
11
+ ],
12
+ "attention_dropout": 0.094,
13
+ "bos_token_id": 1,
14
+ diff --git a/run.sh b/run.sh
15
+ index 9cc498e..5068965 100755
16
+ --- a/run.sh
17
+ +++ b/run.sh
18
+ @@ -1,12 +1,12 @@
19
+ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_ctc.py \
20
+ --model_name_or_path="facebook/wav2vec2-xls-r-1b" \
21
+ - --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst" \
22
+ + --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst-tpu" \
23
+ --tokenizer_name="./" \
24
+ --output_dir="./" \
25
+ --overwrite_output_dir \
26
+ --num_train_epochs="40" \
27
+ - --per_device_train_batch_size="2" \
28
+ - --per_device_eval_batch_size="2" \
29
+ + --per_device_train_batch_size="4" \
30
+ + --per_device_eval_batch_size="4" \
31
+ --gradient_accumulation_steps="1" \
32
+ --precision="full_mixed" \
33
+ --matmul_precision="bfloat16" \
34
+ @@ -16,8 +16,8 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
35
+ --length_column_name="input_length" \
36
+ --evaluation_strategy="steps" \
37
+ --text_column_name="text" \
38
+ - --save_steps="4000" \
39
+ - --eval_steps="4000" \
40
+ + --save_steps="5000" \
41
+ + --eval_steps="5000" \
42
+ --logging_steps="100" \
43
+ --layerdrop="0.041" \
44
+ --attention_dropout="0.094" \
45
+ @@ -42,7 +42,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
46
+ --ctc_zero_infinity \
47
+ --do_lower_case \
48
+ --wandb_project="wav2vec2" \
49
+ - --wandb_name="wav2vec2-1b-npsc-nst" \
50
+ + --wandb_name="wav2vec2-1b-npsc-nst-tpu" \
51
+ --remove_punctuation
52
+
53
+
54
+ diff --git a/special_tokens_map.json b/special_tokens_map.json
55
+ index 89389bf..c3eacb7 100644
56
+ --- a/special_tokens_map.json
57
+ +++ b/special_tokens_map.json
58
+ @@ -343,6 +343,20 @@
59
+ "rstrip": false,
60
+ "single_word": false
61
+ },
62
+ + {
63
+ + "content": "</s>",
64
+ + "lstrip": false,
65
+ + "normalized": true,
66
+ + "rstrip": false,
67
+ + "single_word": false
68
+ + },
69
+ + {
70
+ + "content": "<s>",
71
+ + "lstrip": false,
72
+ + "normalized": true,
73
+ + "rstrip": false,
74
+ + "single_word": false
75
+ + },
76
+ {
77
+ "content": "</s>",
78
+ "lstrip": false,
79
+ diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log
80
+ index 50a0b69..d716bbe 120000
81
+ --- a/wandb/debug-internal.log
82
+ +++ b/wandb/debug-internal.log
83
+ @@ -1 +1 @@
84
+ -run-20220803_091109-yit1e59z/logs/debug-internal.log
85
+
86
+ +run-20220805_124834-3ep5xqhh/logs/debug-internal.log
87
+
88
+ diff --git a/wandb/debug.log b/wandb/debug.log
89
+ index 746223d..92d0ec0 120000
90
+ --- a/wandb/debug.log
91
+ +++ b/wandb/debug.log
92
+ @@ -1 +1 @@
93
+ -run-20220803_091109-yit1e59z/logs/debug.log
94
+
95
+ +run-20220805_124834-3ep5xqhh/logs/debug.log
96
+
97
+ diff --git a/wandb/latest-run b/wandb/latest-run
98
+ index be58b40..075d875 120000
99
+ --- a/wandb/latest-run
100
+ +++ b/wandb/latest-run
101
+ @@ -1 +1 @@
102
+ -run-20220803_091109-yit1e59z
103
+
104
+ +run-20220805_124834-3ep5xqhh
105
+
wandb/run-20220805_124834-3ep5xqhh/files/output.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de21dd0aba5898c5acdb62ceee14797ef9d861ef1edb14f2c99b96ab509c9ccf
3
+ size 237607
wandb/run-20220805_124834-3ep5xqhh/files/requirements.txt ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ appdirs==1.4.4
5
+ astunparse==1.6.3
6
+ async-timeout==4.0.2
7
+ attrs==21.4.0
8
+ audioread==2.1.9
9
+ backcall==0.2.0
10
+ cachetools==4.2.4
11
+ certifi==2021.10.8
12
+ cffi==1.15.1
13
+ charset-normalizer==2.0.10
14
+ chex==0.1.3
15
+ click==8.0.3
16
+ cloud-tpu-client==0.10
17
+ cloud-tpu-profiler==2.4.0
18
+ clu==0.0.6
19
+ colorama==0.4.5
20
+ commonmark==0.9.1
21
+ configparser==5.2.0
22
+ contextlib2==21.6.0
23
+ cycler==0.11.0
24
+ datasets==2.4.0
25
+ decorator==5.1.0
26
+ dill==0.3.4
27
+ dm-tree==0.1.6
28
+ docker-pycreds==0.4.0
29
+ etils==0.6.0
30
+ exceptiongroup==1.0.0rc8
31
+ filelock==3.4.2
32
+ flatbuffers==2.0
33
+ flax==0.5.3
34
+ fonttools==4.28.5
35
+ frozenlist==1.2.0
36
+ fsspec==2021.11.1
37
+ future==0.18.2
38
+ gast==0.4.0
39
+ gitdb==4.0.9
40
+ gitpython==3.1.26
41
+ google-api-core==1.31.5
42
+ google-api-python-client==1.8.0
43
+ google-auth-httplib2==0.1.0
44
+ google-auth-oauthlib==0.4.6
45
+ google-auth==2.3.3
46
+ google-pasta==0.2.0
47
+ googleapis-common-protos==1.54.0
48
+ grpcio==1.43.0
49
+ h5py==3.6.0
50
+ httplib2==0.20.2
51
+ huggingface-hub==0.2.1
52
+ hypothesis==6.53.0
53
+ idna==3.3
54
+ importlib-metadata==4.10.0
55
+ importlib-resources==5.4.0
56
+ ipython==7.31.0
57
+ jax==0.3.15
58
+ jaxlib==0.3.15
59
+ jedi==0.18.1
60
+ jiwer==2.3.0
61
+ joblib==1.1.0
62
+ keras-preprocessing==1.1.2
63
+ keras==2.7.0
64
+ kiwisolver==1.3.2
65
+ libclang==12.0.0
66
+ librosa==0.9.2
67
+ libtpu-nightly==0.1.dev20220722
68
+ llvmlite==0.39.0
69
+ markdown==3.3.6
70
+ matplotlib-inline==0.1.3
71
+ matplotlib==3.5.1
72
+ ml-collections==0.1.0
73
+ msgpack==1.0.3
74
+ multidict==5.2.0
75
+ multiprocess==0.70.12.2
76
+ numba==0.56.0
77
+ numpy==1.22.0
78
+ oauth2client==4.1.3
79
+ oauthlib==3.1.1
80
+ opt-einsum==3.3.0
81
+ optax==0.1.3
82
+ packaging==21.3
83
+ pandas==1.3.5
84
+ parso==0.8.3
85
+ pathtools==0.1.2
86
+ pexpect==4.8.0
87
+ pickleshare==0.7.5
88
+ pillow==9.0.0
89
+ pip==22.2.1
90
+ pkg-resources==0.0.0
91
+ pooch==1.6.0
92
+ promise==2.3
93
+ prompt-toolkit==3.0.24
94
+ protobuf==3.19.1
95
+ psutil==5.9.0
96
+ ptyprocess==0.7.0
97
+ pyarrow==6.0.1
98
+ pyasn1-modules==0.2.8
99
+ pyasn1==0.4.8
100
+ pycparser==2.21
101
+ pyctcdecode==0.4.0
102
+ pygments==2.11.1
103
+ pygtrie==2.5.0
104
+ pyparsing==3.0.6
105
+ python-dateutil==2.8.2
106
+ python-levenshtein==0.12.2
107
+ pytz==2021.3
108
+ pyyaml==6.0
109
+ regex==2021.11.10
110
+ requests-oauthlib==1.3.0
111
+ requests==2.27.0
112
+ resampy==0.3.1
113
+ responses==0.18.0
114
+ rich==11.2.0
115
+ rsa==4.8
116
+ sacremoses==0.0.46
117
+ scikit-learn==1.1.1
118
+ scipy==1.7.3
119
+ sentry-sdk==1.5.2
120
+ setuptools==44.0.0
121
+ shortuuid==1.0.8
122
+ six==1.16.0
123
+ smmap==5.0.0
124
+ sortedcontainers==2.4.0
125
+ soundfile==0.10.3.post1
126
+ sox==1.4.1
127
+ subprocess32==3.5.4
128
+ tensorboard-data-server==0.6.1
129
+ tensorboard-plugin-wit==1.8.0
130
+ tensorboard==2.7.0
131
+ tensorflow-cpu==2.7.0
132
+ tensorflow-datasets==4.4.0
133
+ tensorflow-estimator==2.7.0
134
+ tensorflow-io-gcs-filesystem==0.23.1
135
+ tensorflow-metadata==1.5.0
136
+ tensorflow==2.7.0
137
+ tensorstore==0.1.21
138
+ termcolor==1.1.0
139
+ threadpoolctl==3.1.0
140
+ tokenizers==0.11.2
141
+ toolz==0.11.2
142
+ torch==1.12.0
143
+ torchaudio==0.12.0+cpu
144
+ tqdm==4.62.3
145
+ traitlets==5.1.1
146
+ transformers==4.21.0
147
+ typing-extensions==4.3.0
148
+ uritemplate==3.0.1
149
+ urllib3==1.26.7
150
+ wandb==0.12.9
151
+ wcwidth==0.2.5
152
+ werkzeug==2.0.2
153
+ wheel==0.37.1
154
+ wrapt==1.13.3
155
+ xxhash==2.0.2
156
+ yarl==1.7.2
157
+ yaspin==2.1.0
158
+ zipp==3.7.0
wandb/run-20220805_124834-3ep5xqhh/files/wandb-metadata.json ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2022-08-05T12:48:37.490835",
5
+ "startedAt": "2022-08-05T12:48:34.040204",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--model_name_or_path=facebook/wav2vec2-xls-r-1b",
11
+ "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu",
12
+ "--tokenizer_name=./",
13
+ "--output_dir=./",
14
+ "--overwrite_output_dir",
15
+ "--num_train_epochs=40",
16
+ "--per_device_train_batch_size=4",
17
+ "--per_device_eval_batch_size=4",
18
+ "--gradient_accumulation_steps=1",
19
+ "--precision=full_mixed",
20
+ "--matmul_precision=bfloat16",
21
+ "--multisteps",
22
+ "--learning_rate=2e-4",
23
+ "--warmup_steps=2000",
24
+ "--length_column_name=input_length",
25
+ "--evaluation_strategy=steps",
26
+ "--text_column_name=text",
27
+ "--save_steps=5000",
28
+ "--eval_steps=5000",
29
+ "--logging_steps=100",
30
+ "--layerdrop=0.041",
31
+ "--attention_dropout=0.094",
32
+ "--activation_dropout=0.055",
33
+ "--hidden_dropout=0.047",
34
+ "--save_total_limit=5",
35
+ "--freeze_feature_encoder",
36
+ "--feat_proj_dropout=0.04",
37
+ "--mask_time_prob=0.082",
38
+ "--mask_time_length=10",
39
+ "--mask_feature_prob=0.25",
40
+ "--mask_feature_length=64",
41
+ "--gradient_checkpointing",
42
+ "--min_duration_in_seconds=0.5",
43
+ "--max_duration_in_seconds=30.0",
44
+ "--use_auth_token",
45
+ "--seed=42",
46
+ "--group_by_length",
47
+ "--do_train",
48
+ "--do_eval",
49
+ "--push_to_hub",
50
+ "--preprocessing_num_workers=32",
51
+ "--ctc_zero_infinity",
52
+ "--do_lower_case",
53
+ "--wandb_project=wav2vec2",
54
+ "--wandb_name=wav2vec2-1b-npsc-nst-tpu",
55
+ "--remove_punctuation"
56
+ ],
57
+ "state": "running",
58
+ "program": "run_flax_speech_recognition_ctc.py",
59
+ "codePath": "run_flax_speech_recognition_ctc.py",
60
+ "git": {
61
+ "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu",
62
+ "commit": "e2b1320cc68c3ce129a1d654965e0d3eb44e0558"
63
+ },
64
+ "email": "versae@gmail.com",
65
+ "root": "/data/wav2vec2-1b-npsc-nst",
66
+ "host": "t1v-n-eedfb410-w-0",
67
+ "username": "javierr",
68
+ "executable": "/data/flax/bin/python"
69
+ }
wandb/run-20220805_124834-3ep5xqhh/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"train/grad_norm": 7.09375, "layer_grad_norm/": {"lm_head": {"bias": 0.1953125, "kernel": 2.984375}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 0.1513671875, "scale": 0.1123046875}, "layers": {"0": {"attention": {"k_proj": {"bias": 0.0001430511474609375, "kernel": 0.1162109375}, "out_proj": {"bias": 0.119140625, "kernel": 0.74609375}, "q_proj": {"bias": 0.015380859375, "kernel": 0.1611328125}, "v_proj": {"bias": 0.11328125, "kernel": 0.859375}}, "feed_forward": {"intermediate_dense": {"bias": 0.08544921875, "kernel": 1.296875}, "output_dense": {"bias": 0.026123046875, "kernel": 1.140625}}, "final_layer_norm": {"bias": 0.267578125, "scale": 0.4375}, "layer_norm": {"bias": 0.208984375, "scale": 0.330078125}}, "1": {"attention": {"k_proj": {"bias": 0.00010538101196289062, "kernel": 0.1025390625}, "out_proj": {"bias": 0.0284423828125, "kernel": 0.431640625}, "q_proj": {"bias": 0.008056640625, "kernel": 0.10400390625}, "v_proj": {"bias": 0.04248046875, "kernel": 0.31640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.040283203125, "kernel": 0.6328125}, "output_dense": {"bias": 0.0263671875, "kernel": 0.55859375}}, "final_layer_norm": {"bias": 0.07666015625, "scale": 0.078125}, "layer_norm": {"bias": 0.07958984375, "scale": 0.111328125}}, "10": {"attention": {"k_proj": {"bias": 7.200241088867188e-05, "kernel": 0.185546875}, "out_proj": {"bias": 0.017333984375, "kernel": 0.30859375}, "q_proj": {"bias": 0.01123046875, "kernel": 0.203125}, "v_proj": {"bias": 0.0263671875, "kernel": 0.296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0281982421875, "kernel": 0.4921875}, "output_dense": {"bias": 0.01611328125, "kernel": 0.36328125}}, "final_layer_norm": {"bias": 0.0498046875, "scale": 0.047119140625}, "layer_norm": {"bias": 0.055419921875, "scale": 0.05712890625}}, "11": {"attention": {"k_proj": {"bias": 0.00011920928955078125, "kernel": 0.25}, "out_proj": {"bias": 0.0166015625, "kernel": 0.3984375}, "q_proj": {"bias": 0.01312255859375, "kernel": 0.232421875}, "v_proj": {"bias": 0.027099609375, "kernel": 0.390625}}, "feed_forward": {"intermediate_dense": {"bias": 0.025390625, "kernel": 0.451171875}, "output_dense": {"bias": 0.0159912109375, "kernel": 0.31640625}}, "final_layer_norm": {"bias": 0.04296875, "scale": 0.041015625}, "layer_norm": {"bias": 0.05322265625, "scale": 0.072265625}}, "12": {"attention": {"k_proj": {"bias": 6.866455078125e-05, "kernel": 0.181640625}, "out_proj": {"bias": 0.0162353515625, "kernel": 0.3046875}, "q_proj": {"bias": 0.00927734375, "kernel": 0.17578125}, "v_proj": {"bias": 0.02490234375, "kernel": 0.3046875}}, "feed_forward": {"intermediate_dense": {"bias": 0.026123046875, "kernel": 0.44921875}, "output_dense": {"bias": 0.015380859375, "kernel": 0.30859375}}, "final_layer_norm": {"bias": 0.043701171875, "scale": 0.04296875}, "layer_norm": {"bias": 0.044921875, "scale": 0.056396484375}}, "13": {"attention": {"k_proj": {"bias": 0.0001392364501953125, "kernel": 0.22265625}, "out_proj": {"bias": 0.0162353515625, "kernel": 0.392578125}, "q_proj": {"bias": 0.01123046875, "kernel": 0.216796875}, "v_proj": {"bias": 0.027099609375, "kernel": 0.4140625}}, "feed_forward": {"intermediate_dense": {"bias": 0.024658203125, "kernel": 0.421875}, "output_dense": {"bias": 0.0155029296875, "kernel": 0.30078125}}, "final_layer_norm": {"bias": 0.043212890625, "scale": 0.045166015625}, "layer_norm": {"bias": 0.04443359375, "scale": 0.07470703125}}, "14": {"attention": {"k_proj": {"bias": 0.0002117156982421875, "kernel": 0.146484375}, "out_proj": {"bias": 0.01556396484375, "kernel": 0.3359375}, "q_proj": {"bias": 0.0072021484375, "kernel": 0.142578125}, "v_proj": {"bias": 0.0238037109375, "kernel": 0.34765625}}, "feed_forward": {"intermediate_dense": {"bias": 0.024658203125, "kernel": 0.419921875}, "output_dense": {"bias": 0.01513671875, "kernel": 0.3046875}}, "final_layer_norm": {"bias": 0.04345703125, "scale": 0.04638671875}, "layer_norm": {"bias": 0.038330078125, "scale": 0.04052734375}}, "15": {"attention": {"k_proj": {"bias": 0.0002269744873046875, "kernel": 0.2421875}, "out_proj": {"bias": 0.014892578125, "kernel": 0.4140625}, "q_proj": {"bias": 0.01318359375, "kernel": 0.228515625}, "v_proj": {"bias": 0.024658203125, "kernel": 0.376953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.021240234375, "kernel": 0.359375}, "output_dense": {"bias": 0.0146484375, "kernel": 0.283203125}}, "final_layer_norm": {"bias": 0.036376953125, "scale": 0.03466796875}, "layer_norm": {"bias": 0.0439453125, "scale": 0.0634765625}}, "16": {"attention": {"k_proj": {"bias": 9.34600830078125e-05, "kernel": 0.15234375}, "out_proj": {"bias": 0.01470947265625, "kernel": 0.283203125}, "q_proj": {"bias": 0.00775146484375, "kernel": 0.138671875}, "v_proj": {"bias": 0.021728515625, "kernel": 0.28125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0205078125, "kernel": 0.35546875}, "output_dense": {"bias": 0.0142822265625, "kernel": 0.2734375}}, "final_layer_norm": {"bias": 0.0341796875, "scale": 0.02978515625}, "layer_norm": {"bias": 0.037353515625, "scale": 0.03955078125}}, "17": {"attention": {"k_proj": {"bias": 5.435943603515625e-05, "kernel": 0.1240234375}, "out_proj": {"bias": 0.0152587890625, "kernel": 0.2265625}, "q_proj": {"bias": 0.00628662109375, "kernel": 0.1181640625}, "v_proj": {"bias": 0.021484375, "kernel": 0.23828125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0206298828125, "kernel": 0.353515625}, "output_dense": {"bias": 0.0146484375, "kernel": 0.271484375}}, "final_layer_norm": {"bias": 0.035400390625, "scale": 0.0308837890625}, "layer_norm": {"bias": 0.036865234375, "scale": 0.0400390625}}, "18": {"attention": {"k_proj": {"bias": 0.0001049041748046875, "kernel": 0.1767578125}, "out_proj": {"bias": 0.01470947265625, "kernel": 0.30078125}, "q_proj": {"bias": 0.00860595703125, "kernel": 0.158203125}, "v_proj": {"bias": 0.02197265625, "kernel": 0.28515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.019775390625, "kernel": 0.33984375}, "output_dense": {"bias": 0.01409912109375, "kernel": 0.265625}}, "final_layer_norm": {"bias": 0.0341796875, "scale": 0.030517578125}, "layer_norm": {"bias": 0.0380859375, "scale": 0.0439453125}}, "19": {"attention": {"k_proj": {"bias": 6.580352783203125e-05, "kernel": 0.126953125}, "out_proj": {"bias": 0.01446533203125, "kernel": 0.2255859375}, "q_proj": {"bias": 0.006866455078125, "kernel": 0.1171875}, "v_proj": {"bias": 0.0201416015625, "kernel": 0.234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.018798828125, "kernel": 0.33203125}, "output_dense": {"bias": 0.01416015625, "kernel": 0.263671875}}, "final_layer_norm": {"bias": 0.0306396484375, "scale": 0.02734375}, "layer_norm": {"bias": 0.03173828125, "scale": 0.02734375}}, "2": {"attention": {"k_proj": {"bias": 0.00013065338134765625, "kernel": 0.15625}, "out_proj": {"bias": 0.0289306640625, "kernel": 0.48828125}, "q_proj": {"bias": 0.010986328125, "kernel": 0.154296875}, "v_proj": {"bias": 0.04638671875, "kernel": 0.421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.04296875, "kernel": 0.75390625}, "output_dense": {"bias": 0.02587890625, "kernel": 0.56640625}}, "final_layer_norm": {"bias": 0.0791015625, "scale": 0.076171875}, "layer_norm": {"bias": 0.07763671875, "scale": 0.154296875}}, "20": {"attention": {"k_proj": {"bias": 1.9788742065429688e-05, "kernel": 0.0703125}, "out_proj": {"bias": 0.014892578125, "kernel": 0.146484375}, "q_proj": {"bias": 0.003662109375, "kernel": 0.0673828125}, "v_proj": {"bias": 0.019775390625, "kernel": 0.162109375}}, "feed_forward": {"intermediate_dense": {"bias": 0.01953125, "kernel": 0.34765625}, "output_dense": {"bias": 0.01470947265625, "kernel": 0.26171875}}, "final_layer_norm": {"bias": 0.032470703125, "scale": 0.0296630859375}, "layer_norm": {"bias": 0.0308837890625, "scale": 0.0230712890625}}, "21": {"attention": {"k_proj": {"bias": 6.151199340820312e-05, "kernel": 0.10546875}, "out_proj": {"bias": 0.014892578125, "kernel": 0.224609375}, "q_proj": {"bias": 0.00518798828125, "kernel": 0.09521484375}, "v_proj": {"bias": 0.021240234375, "kernel": 0.24609375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0196533203125, "kernel": 0.353515625}, "output_dense": {"bias": 0.0147705078125, "kernel": 0.263671875}}, "final_layer_norm": {"bias": 0.03271484375, "scale": 0.03369140625}, "layer_norm": {"bias": 0.03076171875, "scale": 0.0322265625}}, "22": {"attention": {"k_proj": {"bias": 3.409385681152344e-05, "kernel": 0.1044921875}, "out_proj": {"bias": 0.015625, "kernel": 0.19921875}, "q_proj": {"bias": 0.00518798828125, "kernel": 0.10205078125}, "v_proj": {"bias": 0.021484375, "kernel": 0.20703125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02001953125, "kernel": 0.353515625}, "output_dense": {"bias": 0.01544189453125, "kernel": 0.26171875}}, "final_layer_norm": {"bias": 0.033447265625, "scale": 0.03857421875}, "layer_norm": {"bias": 0.03515625, "scale": 0.04052734375}}, "23": {"attention": {"k_proj": {"bias": 0.0001430511474609375, "kernel": 0.158203125}, "out_proj": {"bias": 0.015869140625, "kernel": 0.34765625}, "q_proj": {"bias": 0.00762939453125, "kernel": 0.14453125}, "v_proj": {"bias": 0.02392578125, "kernel": 0.32421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.01953125, "kernel": 0.34375}, "output_dense": {"bias": 0.015625, "kernel": 0.25390625}}, "final_layer_norm": {"bias": 0.0322265625, "scale": 0.034912109375}, "layer_norm": {"bias": 0.03759765625, "scale": 0.033203125}}, "24": {"attention": {"k_proj": {"bias": 8.58306884765625e-05, "kernel": 0.150390625}, "out_proj": {"bias": 0.0147705078125, "kernel": 0.26953125}, "q_proj": {"bias": 0.0072021484375, "kernel": 0.1337890625}, "v_proj": {"bias": 0.0234375, "kernel": 0.27734375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0196533203125, "kernel": 0.353515625}, "output_dense": {"bias": 0.014404296875, "kernel": 0.240234375}}, "final_layer_norm": {"bias": 0.03466796875, "scale": 0.03466796875}, "layer_norm": {"bias": 0.041015625, "scale": 0.03369140625}}, "25": {"attention": {"k_proj": {"bias": 9.34600830078125e-05, "kernel": 0.109375}, "out_proj": {"bias": 0.0147705078125, "kernel": 0.251953125}, "q_proj": {"bias": 0.0057373046875, "kernel": 0.1044921875}, "v_proj": {"bias": 0.0224609375, "kernel": 0.26171875}}, "feed_forward": {"intermediate_dense": {"bias": 0.019287109375, "kernel": 0.35546875}, "output_dense": {"bias": 0.014404296875, "kernel": 0.236328125}}, "final_layer_norm": {"bias": 0.033203125, "scale": 0.034912109375}, "layer_norm": {"bias": 0.03466796875, "scale": 0.029296875}}, "26": {"attention": {"k_proj": {"bias": 6.29425048828125e-05, "kernel": 0.12451171875}, "out_proj": {"bias": 0.0145263671875, "kernel": 0.248046875}, "q_proj": {"bias": 0.0064697265625, "kernel": 0.12451171875}, "v_proj": {"bias": 0.022705078125, "kernel": 0.2578125}}, "feed_forward": {"intermediate_dense": {"bias": 0.018798828125, "kernel": 0.32421875}, "output_dense": {"bias": 0.01422119140625, "kernel": 0.23046875}}, "final_layer_norm": {"bias": 0.031982421875, "scale": 0.03173828125}, "layer_norm": {"bias": 0.03564453125, "scale": 0.0296630859375}}, "27": {"attention": {"k_proj": {"bias": 0.000133514404296875, "kernel": 0.14453125}, "out_proj": {"bias": 0.01348876953125, "kernel": 0.30859375}, "q_proj": {"bias": 0.0072021484375, "kernel": 0.1328125}, "v_proj": {"bias": 0.0224609375, "kernel": 0.306640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0184326171875, "kernel": 0.30859375}, "output_dense": {"bias": 0.0135498046875, "kernel": 0.228515625}}, "final_layer_norm": {"bias": 0.0322265625, "scale": 0.0361328125}, "layer_norm": {"bias": 0.039306640625, "scale": 0.03369140625}}, "28": {"attention": {"k_proj": {"bias": 0.00012063980102539062, "kernel": 0.130859375}, "out_proj": {"bias": 0.0128173828125, "kernel": 0.291015625}, "q_proj": {"bias": 0.00640869140625, "kernel": 0.12890625}, "v_proj": {"bias": 0.0213623046875, "kernel": 0.298828125}}, "feed_forward": {"intermediate_dense": {"bias": 0.01708984375, "kernel": 0.29296875}, "output_dense": {"bias": 0.012939453125, "kernel": 0.21875}}, "final_layer_norm": {"bias": 0.0299072265625, "scale": 0.03125}, "layer_norm": {"bias": 0.0361328125, "scale": 0.044921875}}, "29": {"attention": {"k_proj": {"bias": 9.822845458984375e-05, "kernel": 0.134765625}, "out_proj": {"bias": 0.0123291015625, "kernel": 0.24609375}, "q_proj": {"bias": 0.006439208984375, "kernel": 0.1279296875}, "v_proj": {"bias": 0.019775390625, "kernel": 0.251953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0174560546875, "kernel": 0.32421875}, "output_dense": {"bias": 0.0120849609375, "kernel": 0.2197265625}}, "final_layer_norm": {"bias": 0.028076171875, "scale": 0.0302734375}, "layer_norm": {"bias": 0.037841796875, "scale": 0.031982421875}}, "3": {"attention": {"k_proj": {"bias": 0.00019168853759765625, "kernel": 0.2353515625}, "out_proj": {"bias": 0.027099609375, "kernel": 0.62890625}, "q_proj": {"bias": 0.01446533203125, "kernel": 0.228515625}, "v_proj": {"bias": 0.04345703125, "kernel": 0.5546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.04248046875, "kernel": 0.7421875}, "output_dense": {"bias": 0.02490234375, "kernel": 0.5390625}}, "final_layer_norm": {"bias": 0.078125, "scale": 0.0732421875}, "layer_norm": {"bias": 0.07177734375, "scale": 0.076171875}}, "30": {"attention": {"k_proj": {"bias": 0.0001049041748046875, "kernel": 0.15625}, "out_proj": {"bias": 0.0118408203125, "kernel": 0.275390625}, "q_proj": {"bias": 0.007781982421875, "kernel": 0.15625}, "v_proj": {"bias": 0.01953125, "kernel": 0.2890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.016845703125, "kernel": 0.3203125}, "output_dense": {"bias": 0.0115966796875, "kernel": 0.203125}}, "final_layer_norm": {"bias": 0.028076171875, "scale": 0.02880859375}, "layer_norm": {"bias": 0.0303955078125, "scale": 0.042724609375}}, "31": {"attention": {"k_proj": {"bias": 0.00010251998901367188, "kernel": 0.14453125}, "out_proj": {"bias": 0.01104736328125, "kernel": 0.236328125}, "q_proj": {"bias": 0.00714111328125, "kernel": 0.1416015625}, "v_proj": {"bias": 0.0181884765625, "kernel": 0.25390625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0159912109375, "kernel": 0.3046875}, "output_dense": {"bias": 0.0107421875, "kernel": 0.193359375}}, "final_layer_norm": {"bias": 0.0263671875, "scale": 0.025146484375}, "layer_norm": {"bias": 0.030517578125, "scale": 0.04638671875}}, "32": {"attention": {"k_proj": {"bias": 0.0001010894775390625, "kernel": 0.13671875}, "out_proj": {"bias": 0.0103759765625, "kernel": 0.21484375}, "q_proj": {"bias": 0.006561279296875, "kernel": 0.13671875}, "v_proj": {"bias": 0.0166015625, "kernel": 0.244140625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0155029296875, "kernel": 0.30078125}, "output_dense": {"bias": 0.0096435546875, "kernel": 0.1806640625}}, "final_layer_norm": {"bias": 0.0283203125, "scale": 0.02734375}, "layer_norm": {"bias": 0.0281982421875, "scale": 0.0283203125}}, "33": {"attention": {"k_proj": {"bias": 0.00011444091796875, "kernel": 0.16015625}, "out_proj": {"bias": 0.0096435546875, "kernel": 0.2109375}, "q_proj": {"bias": 0.00775146484375, "kernel": 0.158203125}, "v_proj": {"bias": 0.0159912109375, "kernel": 0.232421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.014404296875, "kernel": 0.27734375}, "output_dense": {"bias": 0.0091552734375, "kernel": 0.1796875}}, "final_layer_norm": {"bias": 0.0274658203125, "scale": 0.028564453125}, "layer_norm": {"bias": 0.0263671875, "scale": 0.033447265625}}, "34": {"attention": {"k_proj": {"bias": 0.0001583099365234375, "kernel": 0.15234375}, "out_proj": {"bias": 0.0086669921875, "kernel": 0.21484375}, "q_proj": {"bias": 0.00689697265625, "kernel": 0.14453125}, "v_proj": {"bias": 0.014892578125, "kernel": 0.234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0133056640625, "kernel": 0.251953125}, "output_dense": {"bias": 0.0081787109375, "kernel": 0.171875}}, "final_layer_norm": {"bias": 0.025146484375, "scale": 0.031494140625}, "layer_norm": {"bias": 0.026123046875, "scale": 0.030517578125}}, "35": {"attention": {"k_proj": {"bias": 0.00021076202392578125, "kernel": 0.134765625}, "out_proj": {"bias": 0.007598876953125, "kernel": 0.21484375}, "q_proj": {"bias": 0.006011962890625, "kernel": 0.1318359375}, "v_proj": {"bias": 0.0113525390625, "kernel": 0.1982421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.01068115234375, "kernel": 0.201171875}, "output_dense": {"bias": 0.0074462890625, "kernel": 0.1494140625}}, "final_layer_norm": {"bias": 0.01953125, "scale": 0.0235595703125}, "layer_norm": {"bias": 0.0245361328125, "scale": 0.03076171875}}, "36": {"attention": {"k_proj": {"bias": 0.000156402587890625, "kernel": 0.126953125}, "out_proj": {"bias": 0.00714111328125, "kernel": 0.1630859375}, "q_proj": {"bias": 0.006072998046875, "kernel": 0.1279296875}, "v_proj": {"bias": 0.01019287109375, "kernel": 0.15234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.009765625, "kernel": 0.181640625}, "output_dense": {"bias": 0.006927490234375, "kernel": 0.1337890625}}, "final_layer_norm": {"bias": 0.017333984375, "scale": 0.0185546875}, "layer_norm": {"bias": 0.021240234375, "scale": 0.02001953125}}, "37": {"attention": {"k_proj": {"bias": 0.00011348724365234375, "kernel": 0.126953125}, "out_proj": {"bias": 0.006622314453125, "kernel": 0.1650390625}, "q_proj": {"bias": 0.0074462890625, "kernel": 0.1298828125}, "v_proj": {"bias": 0.010498046875, "kernel": 0.16015625}}, "feed_forward": {"intermediate_dense": {"bias": 0.009521484375, "kernel": 0.177734375}, "output_dense": {"bias": 0.0064697265625, "kernel": 0.126953125}}, "final_layer_norm": {"bias": 0.017333984375, "scale": 0.01611328125}, "layer_norm": {"bias": 0.02587890625, "scale": 0.03857421875}}, "38": {"attention": {"k_proj": {"bias": 0.00012683868408203125, "kernel": 0.1201171875}, "out_proj": {"bias": 0.0062255859375, "kernel": 0.14453125}, "q_proj": {"bias": 0.005462646484375, "kernel": 0.11474609375}, "v_proj": {"bias": 0.00982666015625, "kernel": 0.15234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.00885009765625, "kernel": 0.1611328125}, "output_dense": {"bias": 0.00604248046875, "kernel": 0.12060546875}}, "final_layer_norm": {"bias": 0.0166015625, "scale": 0.016845703125}, "layer_norm": {"bias": 0.0224609375, "scale": 0.021484375}}, "39": {"attention": {"k_proj": {"bias": 9.1552734375e-05, "kernel": 0.103515625}, "out_proj": {"bias": 0.00567626953125, "kernel": 0.150390625}, "q_proj": {"bias": 0.00439453125, "kernel": 0.1005859375}, "v_proj": {"bias": 0.0091552734375, "kernel": 0.1494140625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0078125, "kernel": 0.146484375}, "output_dense": {"bias": 0.0054931640625, "kernel": 0.119140625}}, "final_layer_norm": {"bias": 0.01531982421875, "scale": 0.02001953125}, "layer_norm": {"bias": 0.021484375, "scale": 0.02099609375}}, "4": {"attention": {"k_proj": {"bias": 0.0002593994140625, "kernel": 0.263671875}, "out_proj": {"bias": 0.0255126953125, "kernel": 0.65625}, "q_proj": {"bias": 0.01495361328125, "kernel": 0.255859375}, "v_proj": {"bias": 0.039306640625, "kernel": 0.58984375}}, "feed_forward": {"intermediate_dense": {"bias": 0.041015625, "kernel": 0.6875}, "output_dense": {"bias": 0.02392578125, "kernel": 0.51171875}}, "final_layer_norm": {"bias": 0.0712890625, "scale": 0.07177734375}, "layer_norm": {"bias": 0.06640625, "scale": 0.08837890625}}, "40": {"attention": {"k_proj": {"bias": 5.245208740234375e-05, "kernel": 0.0908203125}, "out_proj": {"bias": 0.005584716796875, "kernel": 0.12353515625}, "q_proj": {"bias": 0.00445556640625, "kernel": 0.0947265625}, "v_proj": {"bias": 0.00848388671875, "kernel": 0.1201171875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0074462890625, "kernel": 0.1279296875}, "output_dense": {"bias": 0.005401611328125, "kernel": 0.1064453125}}, "final_layer_norm": {"bias": 0.0164794921875, "scale": 0.033203125}, "layer_norm": {"bias": 0.0196533203125, "scale": 0.01953125}}, "41": {"attention": {"k_proj": {"bias": 6.151199340820312e-05, "kernel": 0.076171875}, "out_proj": {"bias": 0.00537109375, "kernel": 0.154296875}, "q_proj": {"bias": 0.0035247802734375, "kernel": 0.0791015625}, "v_proj": {"bias": 0.0091552734375, "kernel": 0.162109375}}, "feed_forward": {"intermediate_dense": {"bias": 0.006561279296875, "kernel": 0.119140625}, "output_dense": {"bias": 0.00518798828125, "kernel": 0.10009765625}}, "final_layer_norm": {"bias": 0.01611328125, "scale": 0.0238037109375}, "layer_norm": {"bias": 0.0205078125, "scale": 0.027587890625}}, "42": {"attention": {"k_proj": {"bias": 1.3828277587890625e-05, "kernel": 0.030029296875}, "out_proj": {"bias": 0.00537109375, "kernel": 0.09033203125}, "q_proj": {"bias": 0.001678466796875, "kernel": 0.029541015625}, "v_proj": {"bias": 0.00665283203125, "kernel": 0.08349609375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0059814453125, "kernel": 0.109375}, "output_dense": {"bias": 0.005279541015625, "kernel": 0.0869140625}}, "final_layer_norm": {"bias": 0.0145263671875, "scale": 0.02001953125}, "layer_norm": {"bias": 0.013427734375, "scale": 0.0194091796875}}, "43": {"attention": {"k_proj": {"bias": 4.798173904418945e-06, "kernel": 0.017578125}, "out_proj": {"bias": 0.005401611328125, "kernel": 0.06103515625}, "q_proj": {"bias": 0.0010223388671875, "kernel": 0.017822265625}, "v_proj": {"bias": 0.006103515625, "kernel": 0.0615234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.005767822265625, "kernel": 0.1123046875}, "output_dense": {"bias": 0.005462646484375, "kernel": 0.0859375}}, "final_layer_norm": {"bias": 0.01220703125, "scale": 0.016845703125}, "layer_norm": {"bias": 0.01129150390625, "scale": 0.01507568359375}}, "44": {"attention": {"k_proj": {"bias": 5.8710575103759766e-06, "kernel": 0.0185546875}, "out_proj": {"bias": 0.005584716796875, "kernel": 0.0654296875}, "q_proj": {"bias": 0.00104522705078125, "kernel": 0.01904296875}, "v_proj": {"bias": 0.0062255859375, "kernel": 0.06494140625}}, "feed_forward": {"intermediate_dense": {"bias": 0.005279541015625, "kernel": 0.1142578125}, "output_dense": {"bias": 0.00567626953125, "kernel": 0.080078125}}, "final_layer_norm": {"bias": 0.0106201171875, "scale": 0.01275634765625}, "layer_norm": {"bias": 0.01318359375, "scale": 0.011962890625}}, "45": {"attention": {"k_proj": {"bias": 5.543231964111328e-06, "kernel": 0.017333984375}, "out_proj": {"bias": 0.005767822265625, "kernel": 0.06591796875}, "q_proj": {"bias": 0.001068115234375, "kernel": 0.0177001953125}, "v_proj": {"bias": 0.007110595703125, "kernel": 0.06982421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.00518798828125, "kernel": 0.111328125}, "output_dense": {"bias": 0.005889892578125, "kernel": 0.080078125}}, "final_layer_norm": {"bias": 0.01031494140625, "scale": 0.0118408203125}, "layer_norm": {"bias": 0.0174560546875, "scale": 0.01611328125}}, "46": {"attention": {"k_proj": {"bias": 8.702278137207031e-06, "kernel": 0.01904296875}, "out_proj": {"bias": 0.00604248046875, "kernel": 0.0693359375}, "q_proj": {"bias": 0.001190185546875, "kernel": 0.018310546875}, "v_proj": {"bias": 0.00830078125, "kernel": 0.0810546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.00494384765625, "kernel": 0.095703125}, "output_dense": {"bias": 0.00616455078125, "kernel": 0.076171875}}, "final_layer_norm": {"bias": 0.01171875, "scale": 0.0203857421875}, "layer_norm": {"bias": 0.02294921875, "scale": 0.0159912109375}}, "47": {"attention": {"k_proj": {"bias": 1.2278556823730469e-05, "kernel": 0.020263671875}, "out_proj": {"bias": 0.00634765625, "kernel": 0.06689453125}, "q_proj": {"bias": 0.00138092041015625, "kernel": 0.018798828125}, "v_proj": {"bias": 0.0098876953125, "kernel": 0.08935546875}}, "feed_forward": {"intermediate_dense": {"bias": 0.00445556640625, "kernel": 0.0703125}, "output_dense": {"bias": 0.006561279296875, "kernel": 0.072265625}}, "final_layer_norm": {"bias": 0.01416015625, "scale": 0.0301513671875}, "layer_norm": {"bias": 0.0303955078125, "scale": 0.03271484375}}, "5": {"attention": {"k_proj": {"bias": 0.00010967254638671875, "kernel": 0.228515625}, "out_proj": {"bias": 0.025390625, "kernel": 0.42578125}, "q_proj": {"bias": 0.01287841796875, "kernel": 0.224609375}, "v_proj": {"bias": 0.0400390625, "kernel": 0.41796875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0400390625, "kernel": 0.6484375}, "output_dense": {"bias": 0.02392578125, "kernel": 0.48046875}}, "final_layer_norm": {"bias": 0.0712890625, "scale": 0.0703125}, "layer_norm": {"bias": 0.068359375, "scale": 0.08984375}}, "6": {"attention": {"k_proj": {"bias": 0.00018310546875, "kernel": 0.3203125}, "out_proj": {"bias": 0.023193359375, "kernel": 0.5703125}, "q_proj": {"bias": 0.01806640625, "kernel": 0.296875}, "v_proj": {"bias": 0.0419921875, "kernel": 0.578125}}, "feed_forward": {"intermediate_dense": {"bias": 0.037109375, "kernel": 0.62890625}, "output_dense": {"bias": 0.0216064453125, "kernel": 0.4375}}, "final_layer_norm": {"bias": 0.06396484375, "scale": 0.06103515625}, "layer_norm": {"bias": 0.072265625, "scale": 0.095703125}}, "7": {"attention": {"k_proj": {"bias": 0.0001773834228515625, "kernel": 0.25}, "out_proj": {"bias": 0.0218505859375, "kernel": 0.515625}, "q_proj": {"bias": 0.013916015625, "kernel": 0.240234375}, "v_proj": {"bias": 0.03662109375, "kernel": 0.5078125}}, "feed_forward": {"intermediate_dense": {"bias": 0.035400390625, "kernel": 0.625}, "output_dense": {"bias": 0.02001953125, "kernel": 0.4296875}}, "final_layer_norm": {"bias": 0.061767578125, "scale": 0.06103515625}, "layer_norm": {"bias": 0.06689453125, "scale": 0.083984375}}, "8": {"attention": {"k_proj": {"bias": 0.000164031982421875, "kernel": 0.2353515625}, "out_proj": {"bias": 0.0201416015625, "kernel": 0.45703125}, "q_proj": {"bias": 0.01220703125, "kernel": 0.22265625}, "v_proj": {"bias": 0.0322265625, "kernel": 0.43359375}}, "feed_forward": {"intermediate_dense": {"bias": 0.032470703125, "kernel": 0.57421875}, "output_dense": {"bias": 0.0185546875, "kernel": 0.41015625}}, "final_layer_norm": {"bias": 0.0576171875, "scale": 0.06298828125}, "layer_norm": {"bias": 0.05712890625, "scale": 0.078125}}, "9": {"attention": {"k_proj": {"bias": 0.000255584716796875, "kernel": 0.29296875}, "out_proj": {"bias": 0.017822265625, "kernel": 0.5859375}, "q_proj": {"bias": 0.013916015625, "kernel": 0.27734375}, "v_proj": {"bias": 0.030029296875, "kernel": 0.58203125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02685546875, "kernel": 0.482421875}, "output_dense": {"bias": 0.0172119140625, "kernel": 0.3671875}}, "final_layer_norm": {"bias": 0.043701171875, "scale": 0.04638671875}, "layer_norm": {"bias": 0.0556640625, "scale": 0.06689453125}}}, "pos_conv_embed": {"conv": {"bias": 0.1005859375, "weight_g": 0.09716796875, "weight_v": 0.75}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "1": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "2": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "3": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "4": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "5": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "6": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}}}, "feature_projection": {"layer_norm": {"bias": 0.349609375, "scale": 0.431640625}, "projection": {"bias": 0.162109375, "kernel": 2.96875}}, "masked_spec_embed": 0.0}}, "layer_param_norm/": {"lm_head": {"bias": 0.068152517080307, "kernel": 5.452093124389648}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 0.8908292055130005, "scale": 22.55011749267578}, "layers": {"0": {"attention": {"k_proj": {"bias": 0.2033224105834961, "kernel": 26.631282806396484}, "out_proj": {"bias": 1.5543559789657593, "kernel": 25.581758499145508}, "q_proj": {"bias": 1.3539581298828125, "kernel": 26.875282287597656}, "v_proj": {"bias": 0.3526538014411926, "kernel": 26.176233291625977}}, "feed_forward": {"intermediate_dense": {"bias": 1.820350170135498, "kernel": 96.90775299072266}, "output_dense": {"bias": 1.0441815853118896, "kernel": 92.76568603515625}}, "final_layer_norm": {"bias": 1.3226207494735718, "scale": 19.917068481445312}, "layer_norm": {"bias": 3.2080445289611816, "scale": 15.740410804748535}}, "1": {"attention": {"k_proj": {"bias": 0.20793935656547546, "kernel": 40.073883056640625}, "out_proj": {"bias": 1.325477957725525, "kernel": 42.14421081542969}, "q_proj": {"bias": 2.8881866931915283, "kernel": 39.97734069824219}, "v_proj": {"bias": 0.29305994510650635, "kernel": 40.59188461303711}}, "feed_forward": {"intermediate_dense": {"bias": 1.7098056077957153, "kernel": 94.91741943359375}, "output_dense": {"bias": 0.8041924238204956, "kernel": 85.20159912109375}}, "final_layer_norm": {"bias": 1.2720082998275757, "scale": 18.5848331451416}, "layer_norm": {"bias": 1.8610038757324219, "scale": 18.784591674804688}}, "10": {"attention": {"k_proj": {"bias": 0.23071043193340302, "kernel": 47.817657470703125}, "out_proj": {"bias": 1.260962724685669, "kernel": 50.55366516113281}, "q_proj": {"bias": 2.428488254547119, "kernel": 47.79916000366211}, "v_proj": {"bias": 0.31947579979896545, "kernel": 50.702056884765625}}, "feed_forward": {"intermediate_dense": {"bias": 1.7117397785186768, "kernel": 98.74652862548828}, "output_dense": {"bias": 0.5619156360626221, "kernel": 92.54859161376953}}, "final_layer_norm": {"bias": 2.2949581146240234, "scale": 20.386058807373047}, "layer_norm": {"bias": 1.7546437978744507, "scale": 22.054208755493164}}, "11": {"attention": {"k_proj": {"bias": 0.2500627934932709, "kernel": 47.698875427246094}, "out_proj": {"bias": 1.1115316152572632, "kernel": 49.86783981323242}, "q_proj": {"bias": 2.4702019691467285, "kernel": 47.44954299926758}, "v_proj": {"bias": 0.3839545249938965, "kernel": 50.3409423828125}}, "feed_forward": {"intermediate_dense": {"bias": 1.7618391513824463, "kernel": 99.50746154785156}, "output_dense": {"bias": 0.5432395935058594, "kernel": 94.19831085205078}}, "final_layer_norm": {"bias": 2.2942888736724854, "scale": 20.414934158325195}, "layer_norm": {"bias": 1.7346091270446777, "scale": 22.51268768310547}}, "12": {"attention": {"k_proj": {"bias": 0.23851683735847473, "kernel": 48.23259735107422}, "out_proj": {"bias": 1.0888965129852295, "kernel": 50.09956741333008}, "q_proj": {"bias": 2.3461594581604004, "kernel": 48.003665924072266}, "v_proj": {"bias": 0.3631010949611664, "kernel": 50.46704864501953}}, "feed_forward": {"intermediate_dense": {"bias": 1.803997278213501, "kernel": 100.34449768066406}, "output_dense": {"bias": 0.5307854413986206, "kernel": 95.76693725585938}}, "final_layer_norm": {"bias": 2.242570400238037, "scale": 20.37445640563965}, "layer_norm": {"bias": 1.8127485513687134, "scale": 22.940025329589844}}, "13": {"attention": {"k_proj": {"bias": 0.23810593783855438, "kernel": 50.11332702636719}, "out_proj": {"bias": 1.0852420330047607, "kernel": 49.833316802978516}, "q_proj": {"bias": 2.3189704418182373, "kernel": 49.995277404785156}, "v_proj": {"bias": 0.37377816438674927, "kernel": 49.975830078125}}, "feed_forward": {"intermediate_dense": {"bias": 1.856095552444458, "kernel": 100.97918701171875}, "output_dense": {"bias": 0.5486018657684326, "kernel": 96.21040344238281}}, "final_layer_norm": {"bias": 2.139707088470459, "scale": 20.528472900390625}, "layer_norm": {"bias": 1.9281730651855469, "scale": 23.274127960205078}}, "14": {"attention": {"k_proj": {"bias": 0.2723248600959778, "kernel": 50.209617614746094}, "out_proj": {"bias": 1.2504926919937134, "kernel": 48.13068771362305}, "q_proj": {"bias": 2.3632712364196777, "kernel": 50.29181671142578}, "v_proj": {"bias": 0.35957953333854675, "kernel": 47.67962646484375}}, "feed_forward": {"intermediate_dense": {"bias": 1.8947381973266602, "kernel": 101.58688354492188}, "output_dense": {"bias": 0.5726581811904907, "kernel": 97.57473754882812}}, "final_layer_norm": {"bias": 2.2654528617858887, "scale": 20.67388153076172}, "layer_norm": {"bias": 2.0739710330963135, "scale": 23.353639602661133}}, "15": {"attention": {"k_proj": {"bias": 0.2381843477487564, "kernel": 50.35422897338867}, "out_proj": {"bias": 1.3084888458251953, "kernel": 48.870052337646484}, "q_proj": {"bias": 2.5118086338043213, "kernel": 50.4517822265625}, "v_proj": {"bias": 0.4177994132041931, "kernel": 48.54165267944336}}, "feed_forward": {"intermediate_dense": {"bias": 1.90169095993042, "kernel": 101.39517211914062}, "output_dense": {"bias": 0.728958010673523, "kernel": 98.26675415039062}}, "final_layer_norm": {"bias": 2.2084929943084717, "scale": 20.776508331298828}, "layer_norm": {"bias": 2.2970352172851562, "scale": 23.787940979003906}}, "16": {"attention": {"k_proj": {"bias": 0.20325782895088196, "kernel": 50.171905517578125}, "out_proj": {"bias": 1.232241153717041, "kernel": 48.22490692138672}, "q_proj": {"bias": 2.600128173828125, "kernel": 50.07044982910156}, "v_proj": {"bias": 0.36339572072029114, "kernel": 47.90862274169922}}, "feed_forward": {"intermediate_dense": {"bias": 1.8926692008972168, "kernel": 102.01324462890625}, "output_dense": {"bias": 0.7524616718292236, "kernel": 99.08518981933594}}, "final_layer_norm": {"bias": 2.253352642059326, "scale": 21.251562118530273}, "layer_norm": {"bias": 2.222353458404541, "scale": 22.561655044555664}}, "17": {"attention": {"k_proj": {"bias": 0.19596010446548462, "kernel": 50.29098892211914}, "out_proj": {"bias": 1.1670058965682983, "kernel": 47.52079772949219}, "q_proj": {"bias": 2.6935129165649414, "kernel": 50.364601135253906}, "v_proj": {"bias": 0.4015589952468872, "kernel": 47.18798828125}}, "feed_forward": {"intermediate_dense": {"bias": 1.9025218486785889, "kernel": 103.06841278076172}, "output_dense": {"bias": 0.7684204578399658, "kernel": 99.4727783203125}}, "final_layer_norm": {"bias": 2.332897663116455, "scale": 21.796119689941406}, "layer_norm": {"bias": 2.1179938316345215, "scale": 21.90562629699707}}, "18": {"attention": {"k_proj": {"bias": 0.21306441724300385, "kernel": 50.75135803222656}, "out_proj": {"bias": 1.2695965766906738, "kernel": 48.55261993408203}, "q_proj": {"bias": 2.566439390182495, "kernel": 51.118408203125}, "v_proj": {"bias": 0.42802894115448, "kernel": 48.09583282470703}}, "feed_forward": {"intermediate_dense": {"bias": 1.9341740608215332, "kernel": 103.33280944824219}, "output_dense": {"bias": 0.885681688785553, "kernel": 101.10433959960938}}, "final_layer_norm": {"bias": 2.4191389083862305, "scale": 21.78841781616211}, "layer_norm": {"bias": 2.3001561164855957, "scale": 23.88214874267578}}, "19": {"attention": {"k_proj": {"bias": 0.19877497851848602, "kernel": 49.8792839050293}, "out_proj": {"bias": 1.2443280220031738, "kernel": 48.387168884277344}, "q_proj": {"bias": 2.83821439743042, "kernel": 50.26993179321289}, "v_proj": {"bias": 0.39237794280052185, "kernel": 47.64271545410156}}, "feed_forward": {"intermediate_dense": {"bias": 1.984041452407837, "kernel": 103.87921142578125}, "output_dense": {"bias": 0.9528074860572815, "kernel": 101.99952697753906}}, "final_layer_norm": {"bias": 2.3636670112609863, "scale": 22.150318145751953}, "layer_norm": {"bias": 2.209664821624756, "scale": 22.944355010986328}}, "2": {"attention": {"k_proj": {"bias": 0.22293740510940552, "kernel": 46.508689880371094}, "out_proj": {"bias": 1.2349822521209717, "kernel": 44.43158721923828}, "q_proj": {"bias": 3.102576732635498, "kernel": 46.24449920654297}, "v_proj": {"bias": 0.32414352893829346, "kernel": 44.35133361816406}}, "feed_forward": {"intermediate_dense": {"bias": 1.729475498199463, "kernel": 99.74760437011719}, "output_dense": {"bias": 0.6975946426391602, "kernel": 88.37123107910156}}, "final_layer_norm": {"bias": 1.5292998552322388, "scale": 21.104074478149414}, "layer_norm": {"bias": 1.7049822807312012, "scale": 21.620132446289062}}, "20": {"attention": {"k_proj": {"bias": 0.17897281050682068, "kernel": 49.828556060791016}, "out_proj": {"bias": 1.2720823287963867, "kernel": 47.74470520019531}, "q_proj": {"bias": 2.7510714530944824, "kernel": 50.551265716552734}, "v_proj": {"bias": 0.36270517110824585, "kernel": 46.669921875}}, "feed_forward": {"intermediate_dense": {"bias": 1.978695273399353, "kernel": 105.0912857055664}, "output_dense": {"bias": 1.0657472610473633, "kernel": 102.6424331665039}}, "final_layer_norm": {"bias": 2.3712844848632812, "scale": 23.086246490478516}, "layer_norm": {"bias": 2.184242010116577, "scale": 23.00891876220703}}, "21": {"attention": {"k_proj": {"bias": 0.1918519139289856, "kernel": 50.296329498291016}, "out_proj": {"bias": 1.3105319738388062, "kernel": 47.7496337890625}, "q_proj": {"bias": 2.719822883605957, "kernel": 51.077423095703125}, "v_proj": {"bias": 0.4093472957611084, "kernel": 46.88066101074219}}, "feed_forward": {"intermediate_dense": {"bias": 2.0147135257720947, "kernel": 105.262451171875}, "output_dense": {"bias": 1.1490479707717896, "kernel": 102.94585418701172}}, "final_layer_norm": {"bias": 2.3835391998291016, "scale": 22.749361038208008}, "layer_norm": {"bias": 2.236501455307007, "scale": 23.257896423339844}}, "22": {"attention": {"k_proj": {"bias": 0.20655860006809235, "kernel": 50.70651626586914}, "out_proj": {"bias": 1.236067533493042, "kernel": 47.31515884399414}, "q_proj": {"bias": 2.776611566543579, "kernel": 51.042152404785156}, "v_proj": {"bias": 0.359075129032135, "kernel": 47.19292449951172}}, "feed_forward": {"intermediate_dense": {"bias": 1.9531865119934082, "kernel": 105.61796569824219}, "output_dense": {"bias": 1.1664657592773438, "kernel": 102.2547607421875}}, "final_layer_norm": {"bias": 2.2776260375976562, "scale": 22.2520694732666}, "layer_norm": {"bias": 2.2378594875335693, "scale": 22.298477172851562}}, "23": {"attention": {"k_proj": {"bias": 0.2479013055562973, "kernel": 51.734161376953125}, "out_proj": {"bias": 1.3774523735046387, "kernel": 48.390968322753906}, "q_proj": {"bias": 2.6543586254119873, "kernel": 51.83258056640625}, "v_proj": {"bias": 0.5280159115791321, "kernel": 49.03071594238281}}, "feed_forward": {"intermediate_dense": {"bias": 1.9238386154174805, "kernel": 105.43772888183594}, "output_dense": {"bias": 1.142749547958374, "kernel": 103.04647064208984}}, "final_layer_norm": {"bias": 2.552396297454834, "scale": 22.24652862548828}, "layer_norm": {"bias": 2.7072105407714844, "scale": 23.703632354736328}}, "24": {"attention": {"k_proj": {"bias": 0.2229425311088562, "kernel": 50.472808837890625}, "out_proj": {"bias": 1.435978651046753, "kernel": 50.262542724609375}, "q_proj": {"bias": 2.7695631980895996, "kernel": 50.45745849609375}, "v_proj": {"bias": 0.48706698417663574, "kernel": 50.34023666381836}}, "feed_forward": {"intermediate_dense": {"bias": 2.035344123840332, "kernel": 104.8966293334961}, "output_dense": {"bias": 1.183034896850586, "kernel": 105.89578247070312}}, "final_layer_norm": {"bias": 2.6251282691955566, "scale": 22.29348373413086}, "layer_norm": {"bias": 2.4621849060058594, "scale": 23.277725219726562}}, "25": {"attention": {"k_proj": {"bias": 0.21065406501293182, "kernel": 50.781959533691406}, "out_proj": {"bias": 1.2442545890808105, "kernel": 48.13123321533203}, "q_proj": {"bias": 2.870298147201538, "kernel": 50.558074951171875}, "v_proj": {"bias": 0.5664586424827576, "kernel": 48.67639923095703}}, "feed_forward": {"intermediate_dense": {"bias": 1.9374840259552002, "kernel": 105.15123748779297}, "output_dense": {"bias": 1.0432946681976318, "kernel": 105.82373809814453}}, "final_layer_norm": {"bias": 2.360422134399414, "scale": 22.809398651123047}, "layer_norm": {"bias": 2.59881591796875, "scale": 22.279644012451172}}, "26": {"attention": {"k_proj": {"bias": 0.22183993458747864, "kernel": 51.04751205444336}, "out_proj": {"bias": 1.1648166179656982, "kernel": 48.92171096801758}, "q_proj": {"bias": 2.8587160110473633, "kernel": 50.80378723144531}, "v_proj": {"bias": 0.475290447473526, "kernel": 49.529388427734375}}, "feed_forward": {"intermediate_dense": {"bias": 2.041374683380127, "kernel": 104.55793762207031}, "output_dense": {"bias": 1.0075979232788086, "kernel": 102.93330383300781}}, "final_layer_norm": {"bias": 2.006486415863037, "scale": 21.653133392333984}, "layer_norm": {"bias": 2.478868246078491, "scale": 22.6640625}}, "27": {"attention": {"k_proj": {"bias": 0.43229615688323975, "kernel": 51.916725158691406}, "out_proj": {"bias": 1.4090557098388672, "kernel": 50.325904846191406}, "q_proj": {"bias": 2.6127448081970215, "kernel": 51.7978515625}, "v_proj": {"bias": 0.5855052471160889, "kernel": 50.790611267089844}}, "feed_forward": {"intermediate_dense": {"bias": 2.196718692779541, "kernel": 102.87921142578125}, "output_dense": {"bias": 0.8832440972328186, "kernel": 102.62045288085938}}, "final_layer_norm": {"bias": 2.280376434326172, "scale": 20.95566177368164}, "layer_norm": {"bias": 2.5605921745300293, "scale": 23.58169937133789}}, "28": {"attention": {"k_proj": {"bias": 0.45622092485427856, "kernel": 52.793487548828125}, "out_proj": {"bias": 1.4313125610351562, "kernel": 50.987464904785156}, "q_proj": {"bias": 2.7933201789855957, "kernel": 52.45018005371094}, "v_proj": {"bias": 0.465924471616745, "kernel": 51.342491149902344}}, "feed_forward": {"intermediate_dense": {"bias": 2.139850616455078, "kernel": 102.88671875}, "output_dense": {"bias": 0.7735686302185059, "kernel": 104.78712463378906}}, "final_layer_norm": {"bias": 2.202572822570801, "scale": 21.290332794189453}, "layer_norm": {"bias": 2.0484907627105713, "scale": 24.346450805664062}}, "29": {"attention": {"k_proj": {"bias": 0.21946273744106293, "kernel": 49.203609466552734}, "out_proj": {"bias": 1.4073214530944824, "kernel": 53.4610595703125}, "q_proj": {"bias": 2.729091167449951, "kernel": 49.02506637573242}, "v_proj": {"bias": 0.42470136284828186, "kernel": 53.37601089477539}}, "feed_forward": {"intermediate_dense": {"bias": 2.1310319900512695, "kernel": 103.58599090576172}, "output_dense": {"bias": 0.9014177322387695, "kernel": 109.07768249511719}}, "final_layer_norm": {"bias": 2.3953723907470703, "scale": 22.446285247802734}, "layer_norm": {"bias": 2.1818008422851562, "scale": 25.359933853149414}}, "3": {"attention": {"k_proj": {"bias": 0.26286524534225464, "kernel": 50.66347885131836}, "out_proj": {"bias": 1.402914047241211, "kernel": 47.055397033691406}, "q_proj": {"bias": 2.7446746826171875, "kernel": 50.88633728027344}, "v_proj": {"bias": 0.3097341060638428, "kernel": 47.41233825683594}}, "feed_forward": {"intermediate_dense": {"bias": 1.7387281656265259, "kernel": 101.31109619140625}, "output_dense": {"bias": 0.6584216356277466, "kernel": 91.24249267578125}}, "final_layer_norm": {"bias": 1.818251132965088, "scale": 21.199321746826172}, "layer_norm": {"bias": 1.888195514678955, "scale": 23.423542022705078}}, "30": {"attention": {"k_proj": {"bias": 0.3396183252334595, "kernel": 51.11002731323242}, "out_proj": {"bias": 1.2126681804656982, "kernel": 49.820674896240234}, "q_proj": {"bias": 2.8068060874938965, "kernel": 51.180885314941406}, "v_proj": {"bias": 0.4783928692340851, "kernel": 50.18727493286133}}, "feed_forward": {"intermediate_dense": {"bias": 2.06760835647583, "kernel": 104.07792663574219}, "output_dense": {"bias": 0.8458258509635925, "kernel": 108.11724090576172}}, "final_layer_norm": {"bias": 2.230673313140869, "scale": 23.58123779296875}, "layer_norm": {"bias": 2.303117036819458, "scale": 25.078384399414062}}, "31": {"attention": {"k_proj": {"bias": 0.4084402322769165, "kernel": 49.54199981689453}, "out_proj": {"bias": 1.1298390626907349, "kernel": 50.54130554199219}, "q_proj": {"bias": 2.605281352996826, "kernel": 49.6292724609375}, "v_proj": {"bias": 0.5218612551689148, "kernel": 50.690101623535156}}, "feed_forward": {"intermediate_dense": {"bias": 2.1498327255249023, "kernel": 102.74031829833984}, "output_dense": {"bias": 1.0199881792068481, "kernel": 105.45809936523438}}, "final_layer_norm": {"bias": 2.1171905994415283, "scale": 23.458181381225586}, "layer_norm": {"bias": 2.2760095596313477, "scale": 24.785846710205078}}, "32": {"attention": {"k_proj": {"bias": 0.29190176725387573, "kernel": 48.523231506347656}, "out_proj": {"bias": 1.1341676712036133, "kernel": 49.7735481262207}, "q_proj": {"bias": 2.8203587532043457, "kernel": 48.477691650390625}, "v_proj": {"bias": 0.39520299434661865, "kernel": 50.0748405456543}}, "feed_forward": {"intermediate_dense": {"bias": 2.0911455154418945, "kernel": 101.61611938476562}, "output_dense": {"bias": 1.0767910480499268, "kernel": 104.828369140625}}, "final_layer_norm": {"bias": 2.1094062328338623, "scale": 23.85711669921875}, "layer_norm": {"bias": 2.2725131511688232, "scale": 25.117982864379883}}, "33": {"attention": {"k_proj": {"bias": 0.30211469531059265, "kernel": 48.42399597167969}, "out_proj": {"bias": 1.1708109378814697, "kernel": 49.60016632080078}, "q_proj": {"bias": 2.959620952606201, "kernel": 48.389739990234375}, "v_proj": {"bias": 0.41753631830215454, "kernel": 49.881370544433594}}, "feed_forward": {"intermediate_dense": {"bias": 2.096595048904419, "kernel": 99.9796142578125}, "output_dense": {"bias": 1.0682170391082764, "kernel": 103.541259765625}}, "final_layer_norm": {"bias": 2.007560968399048, "scale": 23.61741828918457}, "layer_norm": {"bias": 2.442840814590454, "scale": 25.323089599609375}}, "34": {"attention": {"k_proj": {"bias": 0.3150138854980469, "kernel": 47.643341064453125}, "out_proj": {"bias": 1.423769235610962, "kernel": 51.03999328613281}, "q_proj": {"bias": 2.888617515563965, "kernel": 47.66726303100586}, "v_proj": {"bias": 0.38617265224456787, "kernel": 50.97981262207031}}, "feed_forward": {"intermediate_dense": {"bias": 2.187166213989258, "kernel": 98.79045104980469}, "output_dense": {"bias": 1.0054786205291748, "kernel": 102.78819274902344}}, "final_layer_norm": {"bias": 1.9427646398544312, "scale": 23.23785972595215}, "layer_norm": {"bias": 2.5134119987487793, "scale": 25.711688995361328}}, "35": {"attention": {"k_proj": {"bias": 0.40803784132003784, "kernel": 49.40678405761719}, "out_proj": {"bias": 1.349858045578003, "kernel": 49.86473846435547}, "q_proj": {"bias": 2.581186294555664, "kernel": 49.70142364501953}, "v_proj": {"bias": 0.4777841567993164, "kernel": 49.73640823364258}}, "feed_forward": {"intermediate_dense": {"bias": 2.2715296745300293, "kernel": 97.324462890625}, "output_dense": {"bias": 0.8976683616638184, "kernel": 101.4354248046875}}, "final_layer_norm": {"bias": 2.017836093902588, "scale": 23.347476959228516}, "layer_norm": {"bias": 2.329078435897827, "scale": 26.2806396484375}}, "36": {"attention": {"k_proj": {"bias": 0.2764836549758911, "kernel": 46.63423156738281}, "out_proj": {"bias": 1.38099205493927, "kernel": 51.14894104003906}, "q_proj": {"bias": 2.6797680854797363, "kernel": 46.574954986572266}, "v_proj": {"bias": 0.36416226625442505, "kernel": 51.36273956298828}}, "feed_forward": {"intermediate_dense": {"bias": 2.155928134918213, "kernel": 96.31394958496094}, "output_dense": {"bias": 0.9205336570739746, "kernel": 101.03855895996094}}, "final_layer_norm": {"bias": 1.7051949501037598, "scale": 23.84616470336914}, "layer_norm": {"bias": 2.04105806350708, "scale": 25.763103485107422}}, "37": {"attention": {"k_proj": {"bias": 0.5623797178268433, "kernel": 45.71692657470703}, "out_proj": {"bias": 1.6310793161392212, "kernel": 51.138118743896484}, "q_proj": {"bias": 2.395171880722046, "kernel": 45.76018524169922}, "v_proj": {"bias": 0.3581632673740387, "kernel": 51.03189468383789}}, "feed_forward": {"intermediate_dense": {"bias": 2.0483310222625732, "kernel": 95.48307800292969}, "output_dense": {"bias": 0.9227561354637146, "kernel": 100.72814178466797}}, "final_layer_norm": {"bias": 1.4903086423873901, "scale": 24.2398738861084}, "layer_norm": {"bias": 2.0101869106292725, "scale": 25.782228469848633}}, "38": {"attention": {"k_proj": {"bias": 0.6424748301506042, "kernel": 43.90583038330078}, "out_proj": {"bias": 1.3247406482696533, "kernel": 50.61767578125}, "q_proj": {"bias": 2.314286231994629, "kernel": 43.89386749267578}, "v_proj": {"bias": 0.40888553857803345, "kernel": 50.517494201660156}}, "feed_forward": {"intermediate_dense": {"bias": 1.9870116710662842, "kernel": 93.46450805664062}, "output_dense": {"bias": 0.906753420829773, "kernel": 98.9403076171875}}, "final_layer_norm": {"bias": 1.5285555124282837, "scale": 24.95897674560547}, "layer_norm": {"bias": 2.1943142414093018, "scale": 26.587631225585938}}, "39": {"attention": {"k_proj": {"bias": 0.6710178256034851, "kernel": 43.74102020263672}, "out_proj": {"bias": 1.6112127304077148, "kernel": 50.4604377746582}, "q_proj": {"bias": 2.1047635078430176, "kernel": 44.102073669433594}, "v_proj": {"bias": 0.39144566655158997, "kernel": 50.153480529785156}}, "feed_forward": {"intermediate_dense": {"bias": 1.9708259105682373, "kernel": 91.68063354492188}, "output_dense": {"bias": 0.9833765029907227, "kernel": 99.22885131835938}}, "final_layer_norm": {"bias": 1.6258282661437988, "scale": 25.595191955566406}, "layer_norm": {"bias": 2.1547136306762695, "scale": 27.168075561523438}}, "4": {"attention": {"k_proj": {"bias": 0.2694867253303528, "kernel": 53.26494598388672}, "out_proj": {"bias": 1.5952560901641846, "kernel": 48.508522033691406}, "q_proj": {"bias": 2.533445358276367, "kernel": 53.46202850341797}, "v_proj": {"bias": 0.35670536756515503, "kernel": 48.84002685546875}}, "feed_forward": {"intermediate_dense": {"bias": 1.7226250171661377, "kernel": 100.87029266357422}, "output_dense": {"bias": 0.8173469305038452, "kernel": 92.466064453125}}, "final_layer_norm": {"bias": 1.9236550331115723, "scale": 20.729034423828125}, "layer_norm": {"bias": 1.9882155656814575, "scale": 23.948692321777344}}, "40": {"attention": {"k_proj": {"bias": 0.6092542409896851, "kernel": 42.94449996948242}, "out_proj": {"bias": 1.5522677898406982, "kernel": 49.07390594482422}, "q_proj": {"bias": 2.02258038520813, "kernel": 43.671382904052734}, "v_proj": {"bias": 0.4459694027900696, "kernel": 48.660919189453125}}, "feed_forward": {"intermediate_dense": {"bias": 1.822402000427246, "kernel": 89.84906005859375}, "output_dense": {"bias": 1.0344607830047607, "kernel": 96.44818878173828}}, "final_layer_norm": {"bias": 1.7864675521850586, "scale": 24.866188049316406}, "layer_norm": {"bias": 2.107056140899658, "scale": 26.691997528076172}}, "41": {"attention": {"k_proj": {"bias": 1.6814589500427246, "kernel": 40.39122772216797}, "out_proj": {"bias": 1.3104321956634521, "kernel": 50.64059066772461}, "q_proj": {"bias": 1.7168387174606323, "kernel": 41.10041427612305}, "v_proj": {"bias": 0.39700430631637573, "kernel": 49.57951354980469}}, "feed_forward": {"intermediate_dense": {"bias": 1.9485447406768799, "kernel": 86.61335754394531}, "output_dense": {"bias": 1.0688600540161133, "kernel": 95.44161224365234}}, "final_layer_norm": {"bias": 2.2872514724731445, "scale": 28.325698852539062}, "layer_norm": {"bias": 2.095804452896118, "scale": 28.469776153564453}}, "42": {"attention": {"k_proj": {"bias": 0.8075435757637024, "kernel": 36.95627212524414}, "out_proj": {"bias": 1.3640947341918945, "kernel": 44.83252716064453}, "q_proj": {"bias": 1.5312429666519165, "kernel": 38.272701263427734}, "v_proj": {"bias": 0.6266118288040161, "kernel": 43.17291259765625}}, "feed_forward": {"intermediate_dense": {"bias": 1.7192597389221191, "kernel": 85.57884979248047}, "output_dense": {"bias": 1.1186156272888184, "kernel": 93.6015853881836}}, "final_layer_norm": {"bias": 1.9876140356063843, "scale": 29.61162757873535}, "layer_norm": {"bias": 1.5676913261413574, "scale": 27.305191040039062}}, "43": {"attention": {"k_proj": {"bias": 1.219111442565918, "kernel": 33.429054260253906}, "out_proj": {"bias": 1.350442886352539, "kernel": 41.2076416015625}, "q_proj": {"bias": 1.3615591526031494, "kernel": 34.21763610839844}, "v_proj": {"bias": 0.5368777513504028, "kernel": 39.10566711425781}}, "feed_forward": {"intermediate_dense": {"bias": 1.737384557723999, "kernel": 84.75312805175781}, "output_dense": {"bias": 0.8779321312904358, "kernel": 91.51695251464844}}, "final_layer_norm": {"bias": 1.9532383680343628, "scale": 31.82501220703125}, "layer_norm": {"bias": 1.704296588897705, "scale": 25.475486755371094}}, "44": {"attention": {"k_proj": {"bias": 2.4949541091918945, "kernel": 34.01667022705078}, "out_proj": {"bias": 1.1109570264816284, "kernel": 44.92375946044922}, "q_proj": {"bias": 1.3070473670959473, "kernel": 34.36231231689453}, "v_proj": {"bias": 0.38874804973602295, "kernel": 44.021461486816406}}, "feed_forward": {"intermediate_dense": {"bias": 1.812990665435791, "kernel": 83.68305969238281}, "output_dense": {"bias": 0.8185780644416809, "kernel": 89.15675354003906}}, "final_layer_norm": {"bias": 1.9702775478363037, "scale": 33.993751525878906}, "layer_norm": {"bias": 1.5985355377197266, "scale": 25.498332977294922}}, "45": {"attention": {"k_proj": {"bias": 2.0529022216796875, "kernel": 33.83161163330078}, "out_proj": {"bias": 0.9920564889907837, "kernel": 48.51708984375}, "q_proj": {"bias": 1.3887860774993896, "kernel": 33.97296905517578}, "v_proj": {"bias": 0.4346076250076294, "kernel": 48.68134307861328}}, "feed_forward": {"intermediate_dense": {"bias": 1.9257898330688477, "kernel": 80.37091827392578}, "output_dense": {"bias": 0.9466295838356018, "kernel": 84.56382751464844}}, "final_layer_norm": {"bias": 1.6998722553253174, "scale": 32.70838165283203}, "layer_norm": {"bias": 1.5237869024276733, "scale": 24.026718139648438}}, "46": {"attention": {"k_proj": {"bias": 1.5453729629516602, "kernel": 34.99570846557617}, "out_proj": {"bias": 0.7569196224212646, "kernel": 50.94432067871094}, "q_proj": {"bias": 1.5501949787139893, "kernel": 35.083892822265625}, "v_proj": {"bias": 0.37976646423339844, "kernel": 51.70716857910156}}, "feed_forward": {"intermediate_dense": {"bias": 1.9667569398880005, "kernel": 74.7667236328125}, "output_dense": {"bias": 1.100508451461792, "kernel": 74.91085815429688}}, "final_layer_norm": {"bias": 1.6546576023101807, "scale": 28.219646453857422}, "layer_norm": {"bias": 1.342303991317749, "scale": 22.941308975219727}}, "47": {"attention": {"k_proj": {"bias": 0.31063246726989746, "kernel": 37.36616897583008}, "out_proj": {"bias": 0.6397010087966919, "kernel": 45.241363525390625}, "q_proj": {"bias": 1.6792919635772705, "kernel": 37.94398880004883}, "v_proj": {"bias": 0.3538339138031006, "kernel": 46.22050476074219}}, "feed_forward": {"intermediate_dense": {"bias": 2.031810760498047, "kernel": 72.15599060058594}, "output_dense": {"bias": 0.6105413436889648, "kernel": 68.52323150634766}}, "final_layer_norm": {"bias": 1.5555658340454102, "scale": 23.038026809692383}, "layer_norm": {"bias": 1.0648114681243896, "scale": 20.196226119995117}}, "5": {"attention": {"k_proj": {"bias": 0.23068922758102417, "kernel": 48.48977279663086}, "out_proj": {"bias": 1.5695205926895142, "kernel": 49.58379364013672}, "q_proj": {"bias": 2.636868953704834, "kernel": 48.62889862060547}, "v_proj": {"bias": 0.3259868621826172, "kernel": 50.311851501464844}}, "feed_forward": {"intermediate_dense": {"bias": 1.649585485458374, "kernel": 100.93144226074219}, "output_dense": {"bias": 0.8427544832229614, "kernel": 91.76881408691406}}, "final_layer_norm": {"bias": 2.1725611686706543, "scale": 20.89813232421875}, "layer_norm": {"bias": 2.0255513191223145, "scale": 23.038808822631836}}, "6": {"attention": {"k_proj": {"bias": 0.3303305506706238, "kernel": 50.40907287597656}, "out_proj": {"bias": 1.554028868675232, "kernel": 49.0322265625}, "q_proj": {"bias": 2.703887701034546, "kernel": 50.87598419189453}, "v_proj": {"bias": 0.32478368282318115, "kernel": 49.51765441894531}}, "feed_forward": {"intermediate_dense": {"bias": 1.620659351348877, "kernel": 100.0206069946289}, "output_dense": {"bias": 0.7007467150688171, "kernel": 91.39401245117188}}, "final_layer_norm": {"bias": 2.4702341556549072, "scale": 20.376611709594727}, "layer_norm": {"bias": 1.9957983493804932, "scale": 23.686267852783203}}, "7": {"attention": {"k_proj": {"bias": 0.3042152523994446, "kernel": 49.961692810058594}, "out_proj": {"bias": 1.3766974210739136, "kernel": 49.27983856201172}, "q_proj": {"bias": 2.458479881286621, "kernel": 50.367801666259766}, "v_proj": {"bias": 0.4104377031326294, "kernel": 49.22698211669922}}, "feed_forward": {"intermediate_dense": {"bias": 1.627145767211914, "kernel": 99.72712707519531}, "output_dense": {"bias": 0.5415225625038147, "kernel": 91.11566162109375}}, "final_layer_norm": {"bias": 2.324662685394287, "scale": 20.59018325805664}, "layer_norm": {"bias": 1.9219211339950562, "scale": 22.409427642822266}}, "8": {"attention": {"k_proj": {"bias": 0.292106032371521, "kernel": 49.483314514160156}, "out_proj": {"bias": 1.1988368034362793, "kernel": 49.76782989501953}, "q_proj": {"bias": 2.43380069732666, "kernel": 49.27770233154297}, "v_proj": {"bias": 0.336185097694397, "kernel": 49.90403747558594}}, "feed_forward": {"intermediate_dense": {"bias": 1.6834964752197266, "kernel": 99.28549194335938}, "output_dense": {"bias": 0.5039160251617432, "kernel": 90.5440902709961}}, "final_layer_norm": {"bias": 2.276369571685791, "scale": 20.358753204345703}, "layer_norm": {"bias": 1.8604187965393066, "scale": 22.826251983642578}}, "9": {"attention": {"k_proj": {"bias": 0.3227398991584778, "kernel": 50.32160186767578}, "out_proj": {"bias": 1.4219694137573242, "kernel": 50.671417236328125}, "q_proj": {"bias": 2.3705806732177734, "kernel": 50.517311096191406}, "v_proj": {"bias": 0.3579637110233307, "kernel": 51.04570770263672}}, "feed_forward": {"intermediate_dense": {"bias": 1.76076340675354, "kernel": 97.91405487060547}, "output_dense": {"bias": 0.6380627155303955, "kernel": 90.99038696289062}}, "final_layer_norm": {"bias": 2.2011656761169434, "scale": 19.671579360961914}, "layer_norm": {"bias": 1.946717381477356, "scale": 24.423818588256836}}}, "pos_conv_embed": {"conv": {"bias": 5.6507768630981445, "weight_g": 8.974504470825195, "weight_v": 85.95291137695312}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 2.0290679931640625, "kernel": 20.55536460876465}, "layer_norm": {"bias": 4.550922393798828, "scale": 16.167570114135742}}, "1": {"conv": {"bias": 1.7790228128433228, "kernel": 51.24136734008789}, "layer_norm": {"bias": 5.962646961212158, "scale": 23.268157958984375}}, "2": {"conv": {"bias": 1.140576720237732, "kernel": 46.50312042236328}, "layer_norm": {"bias": 4.176670551300049, "scale": 20.370853424072266}}, "3": {"conv": {"bias": 0.6725863218307495, "kernel": 44.397525787353516}, "layer_norm": {"bias": 3.888174533843994, "scale": 17.53795051574707}}, "4": {"conv": {"bias": 0.6373162269592285, "kernel": 41.314056396484375}, "layer_norm": {"bias": 2.385471820831299, "scale": 16.34571647644043}}, "5": {"conv": {"bias": 0.5147221684455872, "kernel": 37.479759216308594}, "layer_norm": {"bias": 2.020900011062622, "scale": 17.064470291137695}}, "6": {"conv": {"bias": 0.4947893023490906, "kernel": 40.64780044555664}, "layer_norm": {"bias": 0.5876954793930054, "scale": 19.058603286743164}}}}, "feature_projection": {"layer_norm": {"bias": 6.348620891571045, "scale": 16.567195892333984}, "projection": {"bias": 1.7358717918395996, "kernel": 35.4318733215332}}, "masked_spec_embed": 11.914372444152832}}, "train/learning_rate": 0.0001997361978283152, "train/loss": 1.570224642753601, "train/param_norm": 1196.265869140625, "_runtime": 7703, "_timestamp": 1659711417, "_step": 2500, "_wandb": {"runtime": 7705}}
wandb/run-20220805_124834-3ep5xqhh/logs/debug-internal.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3da4ff9acc37b96e42cab317a91c67a7f027a3b091f75199a9bb88d1da66494
3
+ size 488530
wandb/run-20220805_124834-3ep5xqhh/logs/debug.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48fa6c02e11489989284e294f2bdd85e2a91addb6c39bc72e5836b6a56820370
3
+ size 6061
wandb/run-20220805_124834-3ep5xqhh/run-3ep5xqhh.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d8246b0d8a9370c41447785409866c16f57fb082173b041d6c86503a87ba094
3
+ size 2319945
wandb/run-20220805_152536-2fzkf8n5/files/code/run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import unreplicate, pad_shard_unpad
44
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.utils import check_min_version
59
+ from transformers.utils.versions import require_version
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ check_min_version("4.17.0.dev0")
64
+
65
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ @flax.struct.dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ feature_extractor_name: Optional[str] = field(
86
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
91
+ )
92
+ use_fast_tokenizer: bool = field(
93
+ default=True,
94
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
95
+ )
96
+ model_revision: str = field(
97
+ default="main",
98
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
99
+ )
100
+ use_auth_token: bool = field(
101
+ default=False,
102
+ metadata={
103
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
104
+ "with private models)."
105
+ },
106
+ )
107
+ freeze_feature_encoder: bool = field(
108
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
109
+ )
110
+ attention_dropout: float = field(
111
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
112
+ )
113
+ activation_dropout: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
117
+ },
118
+ )
119
+ hidden_dropout: float = field(
120
+ default=0.1,
121
+ metadata={
122
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
123
+ },
124
+ )
125
+ feat_proj_dropout: float = field(
126
+ default=0.0,
127
+ metadata={
128
+ "help": "The feat proj dropout probability for feature encoder representations."
129
+ },
130
+ )
131
+ final_dropout: float = field(
132
+ default=0.0,
133
+ metadata={"help": "The dropout probability for the final projection layer."},
134
+ )
135
+ mask_time_prob: float = field(
136
+ default=0.1,
137
+ metadata={
138
+ "help": "The spec aug dropout probability for feature encoder representations."
139
+ },
140
+ )
141
+ mask_time_length: int = field(
142
+ default=10,
143
+ metadata={"help": "Length of vector span to mask along the time axis."},
144
+ )
145
+ mask_feature_prob: float = field(
146
+ default=0.0,
147
+ metadata={
148
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
149
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
150
+ },
151
+ )
152
+ mask_feature_length: int = field(
153
+ default=10,
154
+ metadata={"help": "Length of vector span to mask along the feature axis."},
155
+ )
156
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
157
+ ctc_loss_reduction: Optional[str] = field(
158
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
159
+ )
160
+ ctc_zero_infinity: Optional[bool] = field(
161
+ default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."}
162
+ )
163
+
164
+
165
+ @flax.struct.dataclass
166
+ class DataTrainingArguments:
167
+ """
168
+ Arguments pertaining to what data we are going to input our model for training and eval.
169
+ """
170
+
171
+ dataset_name: str = field(
172
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
173
+ )
174
+ dataset_config_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
176
+ )
177
+ text_column: Optional[str] = field(
178
+ default=None,
179
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
180
+ )
181
+ dataset_cache_dir: Optional[str] = field(
182
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
183
+ )
184
+ overwrite_cache: bool = field(
185
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
186
+ )
187
+ preprocessing_num_workers: Optional[int] = field(
188
+ default=None,
189
+ metadata={"help": "The number of processes to use for the preprocessing."},
190
+ )
191
+ max_train_samples: Optional[int] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ },
197
+ )
198
+ max_eval_samples: Optional[int] = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
202
+ "value if set."
203
+ },
204
+ )
205
+ max_test_samples: Optional[int] = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
209
+ "value if set."
210
+ },
211
+ )
212
+ audio_column_name: str = field(
213
+ default="audio",
214
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
215
+ )
216
+ text_column_name: str = field(
217
+ default="text",
218
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
219
+ )
220
+ max_duration_in_seconds: float = field(
221
+ default=20.0,
222
+ metadata={
223
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
224
+ },
225
+ )
226
+ min_duration_in_seconds: float = field(
227
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
228
+ )
229
+ max_label_length: Optional[int] = field(
230
+ default=512,
231
+ metadata={
232
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
233
+ "than this will be filtered."
234
+ },
235
+ )
236
+ min_label_length: Optional[int] = field(
237
+ default=2,
238
+ metadata={
239
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
240
+ "than this will be filtered."
241
+ },
242
+ )
243
+ pad_input_to_multiple_of: Optional[int] = field(
244
+ default=32000,
245
+ metadata={
246
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
247
+ "This is important to avoid triggering recompilations on TPU."
248
+ },
249
+ )
250
+ pad_target_to_multiple_of: Optional[int] = field(
251
+ default=None,
252
+ metadata={
253
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
254
+ "This is important to avoid triggering recompilations on TPU."
255
+ },
256
+ )
257
+ preprocessing_only: bool = field(
258
+ default=False,
259
+ metadata={
260
+ "help": "Whether to only do data preprocessing and skip training. "
261
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
262
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
263
+ "so that the cached datasets can consequently be loaded in distributed training"
264
+ },
265
+ )
266
+ train_split_name: str = field(
267
+ default="train",
268
+ metadata={
269
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
270
+ },
271
+ )
272
+ eval_split_name: str = field(
273
+ default="validation",
274
+ metadata={
275
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
276
+ },
277
+ )
278
+ do_lower_case: bool = field(
279
+ default=True,
280
+ metadata={"help": "Whether the target text should be lower cased."},
281
+ )
282
+ wandb_project: str = field(
283
+ default="flax-speech-recognition-ctc",
284
+ metadata={"help": "The name of the wandb project."},
285
+ )
286
+ wandb_name: str = field(
287
+ default=None,
288
+ metadata={"help": "The name of the wandb run."},
289
+ )
290
+ wandb_job_type: str = field(
291
+ default="CTC",
292
+ metadata={"help": "The name of the wandb job type."},
293
+ )
294
+ test_split_name: str = field(
295
+ default="test",
296
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
297
+ )
298
+ remove_punctuation: bool = field(
299
+ default=False, metadata={"help": "Whether or not to remove punctuation during training."}
300
+ )
301
+ skip_steps: Optional[int] = field(
302
+ default=0,
303
+ metadata={
304
+ "help": "Skip this number of steps. Useful to continue training"
305
+ },
306
+ )
307
+
308
+
309
+ # @flax.struct.dataclass
310
+ @dataclass
311
+ class FlaxTrainingArguments(TrainingArguments):
312
+ precision: str = field(
313
+ default="full",
314
+ metadata={
315
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
316
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
317
+ },
318
+ )
319
+ matmul_precision: str = field(
320
+ default="default",
321
+ metadata={
322
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
323
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
324
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
325
+ "it only changes the behaviors of calls with no such argument provided. "
326
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
327
+ },
328
+ )
329
+ multisteps: bool = field(
330
+ default=False,
331
+ metadata={
332
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
333
+ "a custom gradient accumulation implementation will be employed."
334
+ },
335
+ )
336
+
337
+
338
+ def to_fp32(t):
339
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
340
+
341
+
342
+ def to_bf16(t):
343
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
344
+
345
+
346
+ class MixedPrecisionTrainState(struct.PyTreeNode):
347
+ """Train state for use with a single Optax optimizer.
348
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
349
+
350
+ Synopsis::
351
+
352
+ state = TrainState.create(
353
+ apply_fn=model.apply,
354
+ params=variables['params'],
355
+ tx=tx)
356
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
357
+ for batch in data:
358
+ grads = grad_fn(state.params, batch)
359
+ state = state.apply_gradients(grads=grads)
360
+
361
+ Args:
362
+ step: Counter starts at 0 and is incremented by every call to
363
+ `.apply_gradients()`.
364
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
365
+ convenience to have a shorter params list for the `train_step()` function
366
+ in your training loop.
367
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
368
+ tx: An Optax gradient transformation.
369
+ opt_state: The state for `tx`.
370
+ dropout_rng: PRNG key for stochastic operations.
371
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
372
+ """
373
+
374
+ step: int
375
+ apply_fn: Callable = struct.field(pytree_node=False)
376
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
377
+ params: core.FrozenDict[str, Any]
378
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
379
+ opt_state: optax.OptState
380
+ dropout_rng: jnp.ndarray
381
+ max_grad_norm: Optional[float] = 1.0
382
+
383
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
384
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
385
+
386
+ Note that internally this function calls `.tx.update()` followed by a call
387
+ to `optax.apply_updates()` to update `params` and `opt_state`.
388
+
389
+ Args:
390
+ grads: Gradients that have the same pytree structure as `.params`.
391
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
392
+
393
+ Returns:
394
+ An updated instance of `self` with `step` incremented by one, `params`
395
+ and `opt_state` updated by applying `grads`, and additional attributes
396
+ replaced as specified by `kwargs`.
397
+ """
398
+
399
+ # clip gradients by global l2 norm
400
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
401
+ g_norm = linear_algebra.global_norm(grads)
402
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
403
+ grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
404
+
405
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
406
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
407
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
408
+
409
+ new_params = optax.apply_updates(self.params, updates)
410
+ return self.replace(
411
+ step=self.step + 1,
412
+ params=new_params,
413
+ opt_state=to_dtype(new_opt_state),
414
+ **kwargs,
415
+ )
416
+
417
+ @classmethod
418
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
419
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
420
+ # downcast optimizer state to bf16 if mixed-precision training
421
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
422
+ return cls(
423
+ step=0,
424
+ apply_fn=apply_fn,
425
+ params=params,
426
+ tx=tx,
427
+ opt_state=opt_state,
428
+ **kwargs,
429
+ )
430
+
431
+ def replicate(self):
432
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
433
+
434
+
435
+ @flax.struct.dataclass
436
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
437
+ """
438
+ Data collator that will dynamically pad the inputs received.
439
+ Args:
440
+ processor ([`Wav2Vec2Processor`])
441
+ The processor used for proccessing the data.
442
+ decoder_start_token_id (:obj: `int`)
443
+ The begin-of-sentence of the decoder.
444
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
445
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
446
+ among:
447
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
448
+ sequence if provided).
449
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
450
+ maximum acceptable input length for the model if that argument is not provided.
451
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
452
+ different lengths).
453
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
454
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
455
+ See above for details.
456
+ max_input_length (:obj:`float`, `optional`):
457
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
458
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
459
+ If set will pad the input sequence to a multiple of the provided value.
460
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
461
+ 7.5 (Volta).
462
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
463
+ If set will pad the target sequence to a multiple of the provided value.
464
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
465
+ 7.5 (Volta).
466
+ """
467
+
468
+ processor: Any
469
+ input_padding: Union[bool, str] = "longest"
470
+ label_padding: Union[bool, str] = "max_length"
471
+ pad_input_to_multiple_of: Optional[int] = None
472
+ pad_to_multiple_of_label: Optional[int] = None
473
+ max_input_length: Optional[float] = None
474
+ max_label_length: Optional[float] = None
475
+
476
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
477
+ # split inputs and labels since they have to be of different lengths and need
478
+ # different padding methods
479
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
480
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
481
+
482
+ # reformat list to dict and set to pytorch format
483
+ batch = self.processor.feature_extractor.pad(
484
+ input_features,
485
+ max_length=self.max_input_length,
486
+ padding=self.input_padding,
487
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
488
+ return_tensors="np",
489
+ )
490
+
491
+ labels_batch = self.processor.tokenizer.pad(
492
+ label_features,
493
+ max_length=self.max_label_length,
494
+ padding=self.label_padding,
495
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
496
+ return_tensors="np",
497
+ )
498
+
499
+ labels = labels_batch["input_ids"]
500
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
501
+ labels = labels.filled(fill_value=-100)
502
+
503
+ batch["labels"] = labels
504
+
505
+ return batch
506
+
507
+
508
+ def get_grouped_indices(
509
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
510
+ ) -> np.array:
511
+ """
512
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
513
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
514
+ lengths. To do this, the indices are:
515
+
516
+ - randomly permuted (if a JAX rng is specified)
517
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
518
+ - sorted by length in each mega-batch
519
+
520
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
521
+ maximum length placed first, so that an OOM happens sooner rather than later.
522
+ """
523
+ lengths = dataset["input_length"]
524
+
525
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
526
+ if mega_batch_mult is None:
527
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
528
+ # Just in case, for tiny datasets
529
+ if mega_batch_mult == 0:
530
+ mega_batch_mult = 1
531
+
532
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
533
+ num_samples = len(lengths)
534
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
535
+
536
+ megabatch_size = mega_batch_mult * batch_size
537
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
538
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
539
+
540
+ # The rest is to get the biggest batch first.
541
+ # Since each megabatch is sorted by descending length, the longest element is the first
542
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
543
+ max_idx = np.argmax(megabatch_maximums).item()
544
+ # Switch to put the longest batch in first position
545
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
546
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
547
+
548
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
549
+
550
+ return megabatches
551
+
552
+
553
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
554
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
555
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
556
+ num_samples = len(samples_idx)
557
+ if drop_last:
558
+ samples_to_remove = num_samples % batch_size
559
+ if samples_to_remove != 0:
560
+ samples_idx = samples_idx[:-samples_to_remove]
561
+ sections_split = num_samples // batch_size
562
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
563
+ else:
564
+ sections_split = math.ceil(num_samples / batch_size)
565
+ samples_idx = np.array_split(samples_idx, sections_split)
566
+ return samples_idx
567
+
568
+
569
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
570
+ summary_writer.scalar("train_time", train_time, step)
571
+
572
+ train_metrics = get_metrics(train_metrics)
573
+ for key, vals in train_metrics.items():
574
+ tag = f"train_{key}"
575
+ for i, val in enumerate(vals):
576
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
577
+
578
+
579
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
580
+ for metric_name, value in eval_metrics.items():
581
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
582
+
583
+ if pred_str is not None:
584
+ # write output actual predictions for debugging
585
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
586
+
587
+
588
+ def write_wandb_log(metrics, step, prefix=None):
589
+ if jax.process_index() == 0:
590
+ log_metrics = {}
591
+ for k, v in metrics.items():
592
+ if "layer" in k:
593
+ log_metrics[f"{k}/"] = v
594
+ elif prefix is not None:
595
+ log_metrics[f"{prefix}/{k}"] = v
596
+ else:
597
+ log_metrics[k] = v
598
+ wandb.log(log_metrics, step)
599
+
600
+
601
+ def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"):
602
+ if jax.process_index() == 0:
603
+ # convert str data to a wandb compatible format
604
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
605
+ # we'll log the first 50 predictions for each epoch
606
+ wandb.log(
607
+ {
608
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
609
+ columns=["label_str", "pred_str"], data=str_data[:num_log]
610
+ )
611
+ },
612
+ step,
613
+ )
614
+
615
+
616
+ def create_learning_rate_fn(
617
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
618
+ ) -> Callable[[int], jnp.array]:
619
+ """Returns a linear warmup, linear_decay learning rate function."""
620
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
621
+ decay_fn = optax.linear_schedule(
622
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
623
+ )
624
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
625
+ return schedule_fn
626
+
627
+
628
+ def ctc_loss(
629
+ logits,
630
+ logits_attention_mask,
631
+ labels,
632
+ blank_id,
633
+ loss_reduction="mean",
634
+ output_emission_dict=False,
635
+ log_epsilon=-100000.0,
636
+ ):
637
+ """Computes CTC loss.
638
+ This function performs forward computation over an FSA with `N * 2` states
639
+ where `N` is the max number of labels. The states are split into two groups:
640
+ Phi states and emission states. a phi-state accepts repetition of
641
+ phi (blank)-symbols and transits to emission state when the correct label is
642
+ observed. An emission state accepts repetition of the label and transits to
643
+ the next phi states at any time (so called epsilon-transition).
644
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
645
+ and `N` denotes the time steps in `labels`.
646
+ Args:
647
+ logits: (B, T, K)-array containing log-probabilities of each class.
648
+ logitpaddings: (B, T)-array. Padding indicators for `logits`.
649
+ labels: (B, N)-array containing reference integer labels.
650
+ labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently,
651
+ `labels` must be right-padded, i.e. each row of `labelpaddings` must be
652
+ repetition of zeroes, followed by repetition of ones.
653
+ blank_id: Id for blank token.
654
+ loss_reduction: one of "mean", "sum", "default"
655
+ - "none": no reduction is applied.
656
+ - "mean": output loss will be divided by target lengths and then the
657
+ mean over the batch is taken.
658
+ - "sum": output loss are summed over batch
659
+ output_emission_dict: whether to output additional information about the emission probs
660
+ Returns:
661
+ A pair of `(per_seq_loss, aux)`.
662
+ per_seq_loss:
663
+ (B,)-array containing loss values for each sequence in the batch.
664
+ aux: Dictionary containing interim variables used for computing losses.
665
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
666
+ phi-state corresponding to the n-th label.
667
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
668
+ emission-state corresponding to the n-th label.
669
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
670
+ corresponding to each time frame.
671
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
672
+ corresponding to each time frame.
673
+ """
674
+ # label paddings are indicated by -100
675
+ labelpaddings = labels < 0
676
+ # logit paddings are the inverse of attention_mask
677
+ logitpaddings = ~logits_attention_mask
678
+
679
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
680
+ batchsize, unused_maxinputlen, num_classes = logits.shape
681
+ batchsize_, maxlabellen = labels.shape
682
+
683
+ logprobs = jax.nn.log_softmax(logits)
684
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
685
+
686
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
687
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
688
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
689
+
690
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
691
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
692
+
693
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
694
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
695
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
696
+
697
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
698
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
699
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
700
+
701
+ def loop_body(prev, x):
702
+ prev_phi, prev_emit = prev
703
+ # emit-to-phi epsilon transition, except if the next label is repetition
704
+ prev_phi_orig = prev_phi
705
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
706
+
707
+ logprob_emit, logprob_phi, pad = x
708
+
709
+ # phi-to-emit transition
710
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
711
+ # self-loop transition
712
+ next_phi = prev_phi + logprob_phi
713
+ # emit-to-phi blank transition only when the next label is repetition
714
+ next_phi = next_phi.at[:, 1:].set(
715
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
716
+ )
717
+
718
+ pad = pad.reshape((batchsize, 1))
719
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
720
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
721
+
722
+ return (next_phi, next_emit), (next_phi, next_emit)
723
+
724
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
725
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
726
+
727
+ # last row needs to be updated with the last epsilon transition
728
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
729
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
730
+
731
+ # extract per_seq_loss
732
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
733
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
734
+
735
+ if loss_reduction == "mean":
736
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
737
+ loss = (per_seq_loss / target_lengths).mean()
738
+ elif loss_reduction == "sum":
739
+ loss = per_seq_loss.sum()
740
+ else:
741
+ loss = per_seq_loss
742
+
743
+ if not output_emission_dict:
744
+ return loss
745
+
746
+ return loss, {
747
+ "logalpha_phi": logalpha_phi,
748
+ "logalpha_emit": logalpha_emit,
749
+ "logprobs_phi": logprobs_phi,
750
+ "logprobs_emit": logprobs_emit,
751
+ }
752
+
753
+
754
+ def make_dataset(data_args, seed=42):
755
+ # Pre-processing dataset
756
+ import re
757
+
758
+ def map_nst(entry):
759
+ text = entry["text"].lower()
760
+ text = text.replace("(...vær stille under dette opptaket...)", "")
761
+ text = re.sub('[áàâ]', 'a', text)
762
+ text = re.sub('[ä]', 'æ', text)
763
+ text = re.sub('[éèëê]', 'e', text)
764
+ text = re.sub('[íìïî]', 'i', text)
765
+ text = re.sub('[óòöô]', 'o', text)
766
+ text = re.sub('[ö]', 'ø', text)
767
+ text = re.sub('[ç]', 'c', text)
768
+ text = re.sub('[úùüû]', 'u', text)
769
+ # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
770
+ text = re.sub('\s+', ' ', text)
771
+ return {"text": text}
772
+
773
+ def filter_nst(entry):
774
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
775
+ return False # Too short
776
+ if re.match(entry["type"], "pIW|CA"):
777
+ return False # Spelling out words
778
+ return True
779
+
780
+ def filter_npsc(entry):
781
+ # False if there are digits in the text
782
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
783
+ return False # Too short
784
+ if re.search("\d", entry["text"]):
785
+ return False
786
+ return True
787
+
788
+ def map_npsc(entry):
789
+ batch = {"text": entry["text"].lower()}
790
+ batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
791
+ batch["text"] = re.sub('[ä]', 'æ', batch["text"])
792
+ batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
793
+ batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
794
+ batch["text"] = re.sub('[óòöô]', 'o', batch["text"])
795
+ batch["text"] = re.sub('[ö]', 'ø', batch["text"])
796
+ batch["text"] = re.sub('[ç]', 'c', batch["text"])
797
+ batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
798
+ batch["text"] = re.sub('\s', ' ', batch["text"])
799
+ batch["text"] = re.sub('<ee>', 'eee', batch["text"])
800
+ batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
801
+ batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
802
+ batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
803
+ # batch["text"] = re.sub('<inaudible>', '?', batch["text"])
804
+ if "<" in batch["text"]:
805
+ raise ValueError(batch["text"])
806
+ return batch
807
+
808
+ nst = datasets.load_dataset("NbAiLab/NST", "no-close")
809
+ npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
810
+ # TODO NST_hesitate
811
+
812
+ split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC
813
+ nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed)
814
+ nst[data_args.train_split_name] = nst_train["train"]
815
+ nst[data_args.eval_split_name] = nst_train["test"]
816
+
817
+ nst = nst.filter(filter_nst).map(
818
+ map_nst,
819
+ num_proc=data_args.preprocessing_num_workers,
820
+ desc="filtering NST",
821
+ ).shuffle(seed=seed)
822
+ npsc = npsc.filter(filter_npsc).map(
823
+ map_npsc,
824
+ num_proc=data_args.preprocessing_num_workers,
825
+ desc="filtering NPSC",
826
+ ).shuffle(seed=seed)
827
+
828
+ npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]])
829
+ nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]])
830
+
831
+ combined = {}
832
+ for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name:
833
+ probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples
834
+ probs = (probs / probs.sum()).tolist()
835
+ comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
836
+ combined[split] = comb
837
+
838
+ return datasets.DatasetDict(**combined)
839
+
840
+ def main():
841
+ # 1. Parse input arguments
842
+ # See all possible arguments in src/transformers/training_args.py
843
+ # or by passing the --help flag to this script.
844
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
845
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
846
+
847
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
848
+ # If we pass only one argument to the script and it's the path to a json file,
849
+ # let's parse it to get our arguments.
850
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
851
+ else:
852
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
853
+
854
+ # 2. Setup logging
855
+ # Make one log on every process with the configuration for debugging.
856
+ logging.basicConfig(
857
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
858
+ datefmt="%m/%d/%Y %H:%M:%S",
859
+ handlers=[logging.StreamHandler(sys.stdout)],
860
+ )
861
+ # Set the verbosity to info of the Transformers logger.
862
+ # We only want one process per machine to log things on the screen.
863
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
864
+ if jax.process_index() == 0:
865
+ datasets.utils.logging.set_verbosity_warning()
866
+ transformers.utils.logging.set_verbosity_info()
867
+ else:
868
+ datasets.utils.logging.set_verbosity_error()
869
+ transformers.utils.logging.set_verbosity_error()
870
+
871
+ # Set up wandb run
872
+ if jax.process_index() == 0:
873
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
874
+
875
+ logger.info("Training/evaluation parameters %s", training_args)
876
+
877
+ # Set the default TPU matmul precision and display the number of devices
878
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
879
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
880
+
881
+ # 4. Load dataset
882
+
883
+ set_seed(training_args.seed)
884
+ raw_datasets = make_dataset(data_args, seed=training_args.seed)
885
+
886
+ # raw_datasets = DatasetDict()
887
+
888
+ # if training_args.do_train:
889
+ # raw_datasets[data_args.train_split_name] = load_dataset(
890
+ # data_args.dataset_name,
891
+ # data_args.dataset_config_name,
892
+ # split=data_args.train_split_name,
893
+ # cache_dir=data_args.dataset_cache_dir,
894
+ # use_auth_token=True if model_args.use_auth_token else None,
895
+ # )
896
+
897
+ # if training_args.do_eval:
898
+ # raw_datasets[data_args.eval_split_name] = load_dataset(
899
+ # data_args.dataset_name,
900
+ # data_args.dataset_config_name,
901
+ # split=data_args.eval_split_name,
902
+ # cache_dir=data_args.dataset_cache_dir,
903
+ # use_auth_token=True if model_args.use_auth_token else None,
904
+ # )
905
+
906
+ # if training_args.do_predict:
907
+ # test_split = data_args.test_split_name.split("+")
908
+ # for split in test_split:
909
+ # raw_datasets[split] = load_dataset(
910
+ # data_args.dataset_name,
911
+ # data_args.dataset_config_name,
912
+ # split=split,
913
+ # cache_dir=data_args.dataset_cache_dir,
914
+ # use_auth_token=True if model_args.use_auth_token else None,
915
+ # )
916
+
917
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
918
+ raise ValueError(
919
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
920
+ "training, evaluation or prediction has to be done."
921
+ )
922
+
923
+ # if not training, there is no need to run multiple epochs
924
+ if not training_args.do_train:
925
+ training_args.num_train_epochs = 1
926
+
927
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
928
+ raise ValueError(
929
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
930
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
931
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
932
+ )
933
+
934
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
935
+ raise ValueError(
936
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
937
+ "Make sure to set `--text_column_name` to the correct text column - one of "
938
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
939
+ )
940
+
941
+ # 5. Load pretrained model, tokenizer, and feature extractor
942
+ #
943
+ # Distributed training:
944
+ # The .from_pretrained methods guarantee that only one local process can concurrently
945
+ config = Wav2Vec2Config.from_pretrained(
946
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
947
+ cache_dir=model_args.cache_dir,
948
+ revision=model_args.model_revision,
949
+ use_auth_token=True if model_args.use_auth_token else None,
950
+ )
951
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
952
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
953
+ cache_dir=model_args.cache_dir,
954
+ revision=model_args.model_revision,
955
+ use_auth_token=True if model_args.use_auth_token else None,
956
+ )
957
+ tokenizer = AutoTokenizer.from_pretrained(
958
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
959
+ cache_dir=model_args.cache_dir,
960
+ revision=model_args.model_revision,
961
+ use_auth_token=True if model_args.use_auth_token else None,
962
+ )
963
+ # update config according to training args, model args, and tokenizer attributes
964
+ config.update(
965
+ {
966
+ "feat_proj_dropout": model_args.feat_proj_dropout,
967
+ "attention_dropout": model_args.attention_dropout,
968
+ "hidden_dropout": model_args.hidden_dropout,
969
+ "final_dropout": model_args.final_dropout,
970
+ "mask_time_prob": model_args.mask_time_prob,
971
+ "mask_time_length": model_args.mask_time_length,
972
+ "mask_feature_prob": model_args.mask_feature_prob,
973
+ "mask_feature_length": model_args.mask_feature_length,
974
+ "gradient_checkpointing": training_args.gradient_checkpointing,
975
+ "layerdrop": model_args.layerdrop,
976
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
977
+ "ctc_zero_infinity": model_args.ctc_zero_infinity,
978
+ "pad_token_id": tokenizer.pad_token_id,
979
+ "vocab_size": tokenizer.vocab_size, # len(tokenizer),
980
+ "activation_dropout": model_args.activation_dropout,
981
+ }
982
+ )
983
+
984
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
985
+ raise ValueError(
986
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
987
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus,"
988
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely "
989
+ "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`."
990
+ )
991
+
992
+ if training_args.precision == "full_mixed":
993
+ dtype = jnp.bfloat16
994
+ training_args.mixed_precision = True
995
+ elif training_args.precision == "half_mixed":
996
+ dtype = jnp.bfloat16
997
+ training_args.mixed_precision = False
998
+ else:
999
+ dtype = jnp.float32
1000
+ training_args.mixed_precision = False
1001
+
1002
+ try:
1003
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1004
+ model_args.model_name_or_path,
1005
+ config=config,
1006
+ dtype=dtype,
1007
+ cache_dir=model_args.cache_dir,
1008
+ revision=model_args.model_revision,
1009
+ use_auth_token=True if model_args.use_auth_token else None,
1010
+ )
1011
+ except:
1012
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1013
+ model_args.model_name_or_path,
1014
+ config=config,
1015
+ dtype=dtype,
1016
+ cache_dir=model_args.cache_dir,
1017
+ revision=model_args.model_revision,
1018
+ use_auth_token=True if model_args.use_auth_token else None,
1019
+ from_pt=True,
1020
+ )
1021
+
1022
+ # 6. Resample speech dataset ALWAYS
1023
+ raw_datasets = raw_datasets.cast_column(
1024
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
1025
+ )
1026
+
1027
+ # 7. Preprocessing the datasets.
1028
+ # We need to read the audio files as arrays and tokenize the targets.
1029
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
1030
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
1031
+ max_target_length = data_args.max_label_length
1032
+ min_target_length = data_args.min_label_length
1033
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
1034
+ audio_column_name = data_args.audio_column_name
1035
+ num_workers = data_args.preprocessing_num_workers
1036
+ text_column_name = data_args.text_column_name
1037
+ model_input_name = feature_extractor.model_input_names[0]
1038
+ do_lower_case = data_args.do_lower_case
1039
+ dataset_name = data_args.dataset_name
1040
+ chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
1041
+ chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]'
1042
+ # gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
1043
+ # gigaspeech_disfluencies = ["<other>", "<sil>"]
1044
+ # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
1045
+ # "[vocalized-noise]", "_1"]
1046
+ # swb_punctuations = ["{", "}", "[", "]-", "]"]
1047
+ # earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>", "<unk>"]
1048
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
1049
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
1050
+
1051
+ if training_args.do_train and data_args.max_train_samples is not None:
1052
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples))
1053
+
1054
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1055
+ raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples))
1056
+
1057
+ if training_args.do_predict and data_args.max_test_samples is not None:
1058
+ raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_eval_samples))
1059
+
1060
+ if training_args.do_train and data_args.remove_punctuation:
1061
+
1062
+ def remove_punctuation(batch):
1063
+ batch[text_column_name] = (
1064
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
1065
+ )
1066
+
1067
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map(
1068
+ remove_punctuation,
1069
+ num_proc=data_args.preprocessing_num_workers,
1070
+ desc="removing punctuation from train split",
1071
+ )
1072
+
1073
+ # filter data where the targets are ignored in scoring
1074
+ def is_target_labels(input_str):
1075
+ return input_str.lower() not in ignore_segments
1076
+
1077
+ raw_datasets = raw_datasets.filter(
1078
+ is_target_labels,
1079
+ num_proc=num_workers,
1080
+ input_columns=[text_column_name],
1081
+ desc="filtering data where the targets are ignored in scoring",
1082
+ )
1083
+
1084
+ def prepare_dataset(batch):
1085
+ # process audio
1086
+ try:
1087
+ sample = batch[audio_column_name]
1088
+ except ValueError:
1089
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
1090
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1091
+ # process audio length
1092
+ batch[model_input_name] = inputs.input_values[0]
1093
+ batch["input_length"] = len(batch["input_values"])
1094
+
1095
+ # process targets
1096
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
1097
+
1098
+ # if dataset_name == "google/xtreme_s":
1099
+ # # Finally, we tokenize the processed text
1100
+ # batch["labels"] = tokenizer(input_str).input_ids
1101
+ # batch["labels_length"] = len(batch["labels"])
1102
+ # return batch
1103
+
1104
+ # # Common Voice 9
1105
+ # if input_str.startswith('"') and input_str.endswith('"'):
1106
+ # # we can remove trailing quotation marks as they do not affect the transcription
1107
+ # input_str = input_str[1:-1]
1108
+ # # normalize quotation marks
1109
+ # input_str = re.sub(r'["“”]', '"', input_str)
1110
+ # # normalize apostrophes
1111
+ # input_str = re.sub(r"[’']", "'", input_str)
1112
+ # # normalize hyphens
1113
+ # input_str = re.sub(r"[—–]", "-", input_str)
1114
+ # # replace double quotation marks with single
1115
+ # input_str = input_str.replace('""', '"')
1116
+ # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str):
1117
+ # # for CV9, we'll normalize the text to always finish with punctuation
1118
+ # if input_str[-1] not in [".", "?", "!"]:
1119
+ # input_str = input_str + "."
1120
+
1121
+ # # TEDLIUM-3
1122
+ # # delete the <unk> token from the text and replace spaced apostrophes with un-spaced
1123
+ # input_str = input_str.replace("<unk>", "").replace(" '", "'")
1124
+
1125
+ # # GigaSpeech
1126
+ # for disfluency in gigaspeech_disfluencies:
1127
+ # input_str = input_str.replace(disfluency, "")
1128
+ # # convert spelled out punctuation to symbolic form
1129
+ # for punctuation, replacement in gigaspeech_punctuation.items():
1130
+ # input_str = input_str.replace(punctuation, replacement)
1131
+ # if dataset_name == "speechcolab/gigaspeech" and len(input_str):
1132
+ # # for GS, we'll normalize the text to always finish with punctuation
1133
+ # if input_str[-1] not in [".", "?", "!"]:
1134
+ # input_str = input_str + "."
1135
+
1136
+ # # SWB
1137
+ # for disfluency in swb_disfluencies:
1138
+ # input_str = input_str.replace(disfluency, "")
1139
+ # # remove parenthesised text (test data only)
1140
+ # input_str = re.sub("[\(].*?[\)]", "", input_str)
1141
+ # for punctuation in swb_punctuations:
1142
+ # input_str = input_str.replace(punctuation, "")
1143
+ # # replace anomalous words with their correct transcriptions
1144
+ # split_str = input_str.split("/")
1145
+ # if len(split_str) > 1:
1146
+ # input_str = " ".join(
1147
+ # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1148
+
1149
+ # # Earnings 22
1150
+ # for disfluency in earnings_disfluencies:
1151
+ # input_str = input_str.replace(disfluency, "")
1152
+ # # replace mal-formatted ellipsis
1153
+ # input_str = input_str.replace("…", ".")
1154
+
1155
+ # JIWER compliance
1156
+ # remove multiple spaces
1157
+ input_str = re.sub(r"\s\s+", " ", input_str)
1158
+ # strip trailing spaces
1159
+ input_str = input_str.strip()
1160
+
1161
+ # Finally, we tokenize the processed text
1162
+ batch["labels"] = tokenizer(input_str).input_ids
1163
+ batch["labels_length"] = len(batch["labels"])
1164
+ return batch
1165
+
1166
+ vectorized_datasets = raw_datasets.map(
1167
+ prepare_dataset,
1168
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1169
+ num_proc=num_workers,
1170
+ desc="preprocess dataset",
1171
+ )
1172
+
1173
+ # filter data with inputs shorter than min_input_length or longer than max_input_length
1174
+ def is_audio_in_length_range(length):
1175
+ return length > min_input_length and length < max_input_length
1176
+
1177
+ vectorized_datasets = vectorized_datasets.filter(
1178
+ is_audio_in_length_range,
1179
+ num_proc=num_workers,
1180
+ input_columns=["input_length"],
1181
+ )
1182
+
1183
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1184
+ def is_labels_in_length_range(length):
1185
+ return length > min_target_length # and length < max_target_length
1186
+
1187
+ vectorized_datasets = vectorized_datasets.filter(
1188
+ is_labels_in_length_range,
1189
+ num_proc=num_workers,
1190
+ input_columns=["labels_length"],
1191
+ )
1192
+
1193
+ # for large datasets it is advised to run the preprocessing on a
1194
+ # single machine first with `args.preprocessing_only` since there will mostly likely
1195
+ # be a timeout when running the script in distributed mode.
1196
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1197
+ # cached dataset
1198
+ if data_args.preprocessing_only:
1199
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1200
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1201
+ return
1202
+
1203
+ # 8. Load Metrics
1204
+ wer_metric = load_metric("wer")
1205
+ cer_metric = load_metric("cer")
1206
+
1207
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1208
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1209
+
1210
+ pred_str = tokenizer.batch_decode(pred_ids)
1211
+ # we do not want to group tokens when computing the metrics
1212
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
1213
+
1214
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
1215
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
1216
+
1217
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1218
+
1219
+ # 9. save feature extractor, tokenizer and config
1220
+ feature_extractor.save_pretrained(training_args.output_dir)
1221
+ tokenizer.save_pretrained(training_args.output_dir)
1222
+ config.save_pretrained(training_args.output_dir)
1223
+
1224
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1225
+
1226
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1227
+ processor=processor,
1228
+ input_padding="longest",
1229
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1230
+ max_label_length=data_args.max_label_length,
1231
+ )
1232
+
1233
+ # Enable tensorboard only on the master node
1234
+ has_tensorboard = is_tensorboard_available()
1235
+ if has_tensorboard and jax.process_index() == 0:
1236
+ try:
1237
+ from flax.metrics.tensorboard import SummaryWriter
1238
+
1239
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1240
+ except ImportError as ie:
1241
+ has_tensorboard = False
1242
+ logger.warning(
1243
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1244
+ )
1245
+ else:
1246
+ logger.warning(
1247
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1248
+ "Please run `pip install tensorboard` to enable."
1249
+ )
1250
+
1251
+ # 10. Handle the repository creation
1252
+ if training_args.push_to_hub:
1253
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1254
+ git_lfs_extensions = f.read()
1255
+ if "*.wandb" not in git_lfs_extensions:
1256
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1257
+ if training_args.hub_model_id is None:
1258
+ repo_name = get_full_repo_name(
1259
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1260
+ )
1261
+ else:
1262
+ repo_name = training_args.hub_model_id
1263
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1264
+
1265
+ # 11. Initialize our training
1266
+ rng = jax.random.PRNGKey(training_args.seed)
1267
+ rng, dropout_rng = jax.random.split(rng)
1268
+
1269
+ # Store some constants
1270
+ max_steps = int(training_args.max_steps)
1271
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1272
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1273
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1274
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1275
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1276
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1277
+
1278
+ if training_args.do_train:
1279
+ num_train_samples = len(vectorized_datasets[data_args.train_split_name])
1280
+ steps_per_epoch = num_train_samples // batch_size_per_update
1281
+ if max_steps > 0:
1282
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1283
+ total_train_steps = max_steps
1284
+ else:
1285
+ num_epochs = int(training_args.num_train_epochs)
1286
+ total_train_steps = steps_per_epoch * num_epochs
1287
+
1288
+ # Create learning rate schedule
1289
+ # Create learning rate schedule
1290
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1291
+ total_train_steps,
1292
+ training_args.warmup_steps,
1293
+ training_args.learning_rate,
1294
+ )
1295
+
1296
+ # We use Optax's "masking" functionality to not apply weight decay
1297
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1298
+ # mask boolean with the same structure as the parameters.
1299
+ # The mask is True for parameters that should be decayed.
1300
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1301
+ # For FlaxT5, one should correct the layer norm parameter naming
1302
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1303
+ def decay_mask_fn(params):
1304
+ flat_params = traverse_util.flatten_dict(params)
1305
+ layer_norm_params = [
1306
+ (name, "scale")
1307
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1308
+ ]
1309
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1310
+ return traverse_util.unflatten_dict(flat_mask)
1311
+
1312
+ if training_args.adafactor:
1313
+ # Create Adafactor optimizer
1314
+ optim = optax.adafactor(
1315
+ learning_rate=linear_decay_lr_schedule_fn,
1316
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1317
+ weight_decay_rate=training_args.weight_decay,
1318
+ weight_decay_mask=decay_mask_fn,
1319
+ )
1320
+ else:
1321
+ # Create AdamW optimizer
1322
+ optim = optax.adamw(
1323
+ learning_rate=linear_decay_lr_schedule_fn,
1324
+ b1=training_args.adam_beta1,
1325
+ b2=training_args.adam_beta2,
1326
+ eps=training_args.adam_epsilon,
1327
+ weight_decay=training_args.weight_decay,
1328
+ mask=decay_mask_fn,
1329
+ )
1330
+
1331
+ # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
1332
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1333
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
1334
+ else:
1335
+ num_epochs = 0
1336
+ total_train_steps = 0
1337
+ num_train_samples = 0
1338
+ optim = None
1339
+
1340
+ # Setup train state
1341
+ state = MixedPrecisionTrainState.create(
1342
+ apply_fn=model.__call__,
1343
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1344
+ params=model.params,
1345
+ tx=optim,
1346
+ to_dtype=to_dtype,
1347
+ dropout_rng=dropout_rng,
1348
+ max_grad_norm=training_args.max_grad_norm,
1349
+ )
1350
+
1351
+ # Replicate the train state on each device
1352
+ state = state.replicate()
1353
+ blank_id = model.config.pad_token_id
1354
+
1355
+ # Define gradient update step fn
1356
+ def train_step(state, batch):
1357
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1358
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1359
+
1360
+ def compute_loss(params, minibatch):
1361
+ labels = minibatch.pop("labels")
1362
+ logits = state.apply_fn(
1363
+ **minibatch,
1364
+ params=params,
1365
+ dropout_rng=dropout_rng,
1366
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1367
+ train=True,
1368
+ )[0]
1369
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"])
1370
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1371
+
1372
+ return loss
1373
+
1374
+ grad_fn = jax.value_and_grad(compute_loss)
1375
+
1376
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1377
+ loss, grad = grad_fn(to_dtype(state.params), batch)
1378
+
1379
+ # Custom gradient accumulation
1380
+ else:
1381
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1382
+ batch = jax.tree_util.tree_map(
1383
+ lambda x: x.reshape(
1384
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1385
+ ),
1386
+ batch,
1387
+ )
1388
+
1389
+ def accum_minibatch_step(accum_grad, minibatch):
1390
+ # compute loss, num labels and grad over minibatch and accumulate
1391
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
1392
+ return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss
1393
+
1394
+ # create an initial state for accumulating losses, num labels and gradients
1395
+ init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params))
1396
+ # loop accum minibatch step over the number of gradient accumulation steps
1397
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
1398
+
1399
+ # update state
1400
+ new_state = state.apply_gradients(
1401
+ grads=grad,
1402
+ dropout_rng=new_dropout_rng,
1403
+ to_dtype=to_dtype,
1404
+ )
1405
+
1406
+ # compute gradient norms over all layers and globally for detailed monitoring
1407
+ layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad)
1408
+ logs = {
1409
+ "layer_grad_norm": layer_grad_norm,
1410
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
1411
+ }
1412
+
1413
+ # compute parameter norms over all layers and globally for detailed monitoring
1414
+ layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params)
1415
+ logs["layer_param_norm"] = layer_param_norm
1416
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
1417
+
1418
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1419
+ metrics.update(logs)
1420
+
1421
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1422
+ # metrics = to_fp32(metrics)
1423
+
1424
+ return new_state, metrics
1425
+
1426
+ # Define eval fn
1427
+ def eval_step(params, batch):
1428
+ labels = batch.pop("labels")
1429
+ logits = model(**batch, params=params, train=False)[0]
1430
+
1431
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
1432
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1433
+
1434
+ pred_ids = jnp.argmax(logits, axis=-1)
1435
+
1436
+ # summarize metrics
1437
+ metrics = {"loss": loss}
1438
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1439
+ # metrics = to_fp32(metrics)
1440
+ return metrics, pred_ids
1441
+
1442
+ # Create parallel version of the train and eval step
1443
+ if training_args.do_train:
1444
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1445
+
1446
+ if training_args.do_eval:
1447
+ p_eval_step = jax.pmap(eval_step, "batch")
1448
+
1449
+ def run_evaluation(step):
1450
+ if training_args.do_eval:
1451
+ # ======================== Evaluating ==============================
1452
+ eval_metrics = []
1453
+ eval_preds = []
1454
+ eval_labels = []
1455
+
1456
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1457
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size)
1458
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1459
+
1460
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1461
+ samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx]
1462
+ batch = data_collator(samples)
1463
+ labels = batch["labels"]
1464
+
1465
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1466
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1467
+ eval_metrics.append(metrics)
1468
+
1469
+ eval_labels.extend(labels)
1470
+
1471
+ # normalize eval metrics
1472
+ eval_metrics = get_metrics(eval_metrics)
1473
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1474
+ eval_metrics = to_fp32(eval_metrics)
1475
+
1476
+ # always run compute metrics
1477
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1478
+ eval_metrics.update(error_rate_metric)
1479
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1480
+
1481
+ # Print metrics and update progress bar
1482
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1483
+ epochs.write(desc)
1484
+ epochs.desc = desc
1485
+
1486
+ # Save metrics
1487
+ write_wandb_log(eval_metrics, step, prefix="eval")
1488
+ write_wandb_pred(pred_str, label_str, step)
1489
+ # if has_tensorboard and jax.process_index() == 0:
1490
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1491
+
1492
+ def save_checkpoint(step):
1493
+ # save and push checkpoint to the hub
1494
+ if jax.process_index() == 0:
1495
+ params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
1496
+ model.save_pretrained(training_args.output_dir, params=params)
1497
+ tokenizer.save_pretrained(training_args.output_dir)
1498
+ if training_args.push_to_hub:
1499
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1500
+
1501
+ skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update)
1502
+ logger.info("***** Running training *****")
1503
+ logger.info(f" Num examples = {num_train_samples}")
1504
+ logger.info(f" Num Epochs = {num_epochs}")
1505
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1506
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1507
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1508
+ logger.info(f" Total optimization steps = {total_train_steps}")
1509
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
1510
+ logger.info(f" Use scan: {config.use_scan}")
1511
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
1512
+ logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)")
1513
+
1514
+ train_time = cur_step = 0
1515
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1516
+ for epoch in epochs:
1517
+ if training_args.do_train:
1518
+ # ======================== Training ================================
1519
+ train_start = time.time()
1520
+
1521
+ if epoch < skip_epochs:
1522
+ logger.info(f"Skipping epoch {epoch + 1}")
1523
+ continue
1524
+
1525
+ # Create sampling rng
1526
+ rng, input_rng = jax.random.split(rng)
1527
+
1528
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1529
+ train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
1530
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1531
+
1532
+ if data_args.skip_steps > cur_step:
1533
+ logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...")
1534
+ # Gather the indices for creating the batch and do a training step
1535
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1536
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1537
+ if cur_step <= data_args.skip_steps:
1538
+ continue
1539
+
1540
+ samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx]
1541
+ batch = data_collator(samples)
1542
+ batch = shard(batch.data)
1543
+ try:
1544
+ state, train_metric = p_train_step(state, batch)
1545
+ except TypeError as e:
1546
+ logger.warning("Encountered following error: \n", e)
1547
+
1548
+
1549
+ if cur_step % training_args.logging_steps == 0:
1550
+ # Save metrics
1551
+ train_metric = unreplicate(train_metric)
1552
+ train_time += time.time() - train_start
1553
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1554
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name)
1555
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1556
+ # if has_tensorboard and jax.process_index() == 0:
1557
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1558
+
1559
+ epochs.write(
1560
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1561
+ )
1562
+
1563
+ if cur_step % total_train_steps == 0:
1564
+ break
1565
+
1566
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1567
+ run_evaluation(cur_step)
1568
+
1569
+ if cur_step % training_args.save_steps == 0:
1570
+ save_checkpoint(cur_step)
1571
+
1572
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1573
+ # run evaluation at the end of the epoch if eval steps are not specified
1574
+ run_evaluation(cur_step)
1575
+ save_checkpoint(cur_step)
1576
+
1577
+ if training_args.do_train:
1578
+ save_checkpoint(cur_step)
1579
+
1580
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1581
+
1582
+ if training_args.do_eval:
1583
+ run_evaluation(cur_step)
1584
+
1585
+ # TODO: collapse 'do_predict' into the run_evaluation function
1586
+ if training_args.do_predict:
1587
+ for split in [data_args.test_split_name]:
1588
+ # ======================== Evaluating ==============================
1589
+ eval_metrics = []
1590
+ eval_preds = []
1591
+ eval_labels = []
1592
+
1593
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
1594
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1595
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1596
+
1597
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
1598
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1599
+ batch = data_collator(samples)
1600
+ labels = batch["labels"]
1601
+
1602
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1603
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1604
+ eval_metrics.append(metrics)
1605
+
1606
+ eval_labels.extend(labels)
1607
+
1608
+ # normalize eval metrics
1609
+ eval_metrics = get_metrics(eval_metrics)
1610
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1611
+ eval_metrics = to_fp32(eval_metrics)
1612
+
1613
+ # always run compute metrics
1614
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1615
+ eval_metrics.update(error_rate_metric)
1616
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1617
+
1618
+ # Print metrics and update progress bar
1619
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1620
+ epochs.write(desc)
1621
+ epochs.desc = desc
1622
+
1623
+ # Save metrics
1624
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
1625
+ write_wandb_pred(pred_str, label_str, cur_step, prefix=split)
1626
+ # if has_tensorboard and jax.process_index() == 0:
1627
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
1628
+
1629
+
1630
+ if __name__ == "__main__":
1631
+ main()
wandb/run-20220805_152536-2fzkf8n5/files/config.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.9
7
+ code_path: code/run_flax_speech_recognition_ctc.py
8
+ framework: huggingface
9
+ huggingface_version: 4.21.0
10
+ is_jupyter_run: false
11
+ is_kaggle_kernel: false
12
+ python_version: 3.8.10
13
+ start_time: 1659713136
14
+ t:
15
+ 1:
16
+ - 1
17
+ - 2
18
+ - 3
19
+ - 11
20
+ - 12
21
+ 2:
22
+ - 1
23
+ - 2
24
+ - 3
25
+ - 11
26
+ - 12
27
+ 3:
28
+ - 13
29
+ 4: 3.8.10
30
+ 5: 0.12.9
31
+ 6: 4.21.0
32
+ 8:
33
+ - 5
wandb/run-20220805_152536-2fzkf8n5/files/diff.patch ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/config.json b/config.json
2
+ index 260219f..246b797 100644
3
+ --- a/config.json
4
+ +++ b/config.json
5
+ @@ -5,7 +5,7 @@
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ - "Wav2Vec2ForCTC"
10
+ + "Wav2Vec2ForPreTraining"
11
+ ],
12
+ "attention_dropout": 0.094,
13
+ "bos_token_id": 1,
14
+ diff --git a/run.sh b/run.sh
15
+ index 9cc498e..d5f5166 100755
16
+ --- a/run.sh
17
+ +++ b/run.sh
18
+ @@ -1,13 +1,13 @@
19
+ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_ctc.py \
20
+ --model_name_or_path="facebook/wav2vec2-xls-r-1b" \
21
+ - --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst" \
22
+ + --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst-tpu" \
23
+ --tokenizer_name="./" \
24
+ --output_dir="./" \
25
+ --overwrite_output_dir \
26
+ --num_train_epochs="40" \
27
+ --per_device_train_batch_size="2" \
28
+ --per_device_eval_batch_size="2" \
29
+ - --gradient_accumulation_steps="1" \
30
+ + --gradient_accumulation_steps="2" \
31
+ --precision="full_mixed" \
32
+ --matmul_precision="bfloat16" \
33
+ --multisteps \
34
+ @@ -16,8 +16,8 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
35
+ --length_column_name="input_length" \
36
+ --evaluation_strategy="steps" \
37
+ --text_column_name="text" \
38
+ - --save_steps="4000" \
39
+ - --eval_steps="4000" \
40
+ + --save_steps="5000" \
41
+ + --eval_steps="5000" \
42
+ --logging_steps="100" \
43
+ --layerdrop="0.041" \
44
+ --attention_dropout="0.094" \
45
+ @@ -42,7 +42,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
46
+ --ctc_zero_infinity \
47
+ --do_lower_case \
48
+ --wandb_project="wav2vec2" \
49
+ - --wandb_name="wav2vec2-1b-npsc-nst" \
50
+ + --wandb_name="wav2vec2-1b-npsc-nst-tpu" \
51
+ --remove_punctuation
52
+
53
+
54
+ diff --git a/special_tokens_map.json b/special_tokens_map.json
55
+ index 89389bf..81394dd 100644
56
+ --- a/special_tokens_map.json
57
+ +++ b/special_tokens_map.json
58
+ @@ -343,6 +343,34 @@
59
+ "rstrip": false,
60
+ "single_word": false
61
+ },
62
+ + {
63
+ + "content": "</s>",
64
+ + "lstrip": false,
65
+ + "normalized": true,
66
+ + "rstrip": false,
67
+ + "single_word": false
68
+ + },
69
+ + {
70
+ + "content": "<s>",
71
+ + "lstrip": false,
72
+ + "normalized": true,
73
+ + "rstrip": false,
74
+ + "single_word": false
75
+ + },
76
+ + {
77
+ + "content": "</s>",
78
+ + "lstrip": false,
79
+ + "normalized": true,
80
+ + "rstrip": false,
81
+ + "single_word": false
82
+ + },
83
+ + {
84
+ + "content": "<s>",
85
+ + "lstrip": false,
86
+ + "normalized": true,
87
+ + "rstrip": false,
88
+ + "single_word": false
89
+ + },
90
+ {
91
+ "content": "</s>",
92
+ "lstrip": false,
93
+ diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log
94
+ index 50a0b69..6eb3065 120000
95
+ --- a/wandb/debug-internal.log
96
+ +++ b/wandb/debug-internal.log
97
+ @@ -1 +1 @@
98
+ -run-20220803_091109-yit1e59z/logs/debug-internal.log
99
+
100
+ +run-20220805_152536-2fzkf8n5/logs/debug-internal.log
101
+
102
+ diff --git a/wandb/debug.log b/wandb/debug.log
103
+ index 746223d..31f5db8 120000
104
+ --- a/wandb/debug.log
105
+ +++ b/wandb/debug.log
106
+ @@ -1 +1 @@
107
+ -run-20220803_091109-yit1e59z/logs/debug.log
108
+
109
+ +run-20220805_152536-2fzkf8n5/logs/debug.log
110
+
111
+ diff --git a/wandb/latest-run b/wandb/latest-run
112
+ index be58b40..d408175 120000
113
+ --- a/wandb/latest-run
114
+ +++ b/wandb/latest-run
115
+ @@ -1 +1 @@
116
+ -run-20220803_091109-yit1e59z
117
+
118
+ +run-20220805_152536-2fzkf8n5
119
+
wandb/run-20220805_152536-2fzkf8n5/files/output.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40ff34b9c765c40e17b2e182bf4b3c8d4c775695283737622845e6848a90e685
3
+ size 231583
wandb/run-20220805_152536-2fzkf8n5/files/requirements.txt ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ appdirs==1.4.4
5
+ astunparse==1.6.3
6
+ async-timeout==4.0.2
7
+ attrs==21.4.0
8
+ audioread==2.1.9
9
+ backcall==0.2.0
10
+ cachetools==4.2.4
11
+ certifi==2021.10.8
12
+ cffi==1.15.1
13
+ charset-normalizer==2.0.10
14
+ chex==0.1.3
15
+ click==8.0.3
16
+ cloud-tpu-client==0.10
17
+ cloud-tpu-profiler==2.4.0
18
+ clu==0.0.6
19
+ colorama==0.4.5
20
+ commonmark==0.9.1
21
+ configparser==5.2.0
22
+ contextlib2==21.6.0
23
+ cycler==0.11.0
24
+ datasets==2.4.0
25
+ decorator==5.1.0
26
+ dill==0.3.4
27
+ dm-tree==0.1.6
28
+ docker-pycreds==0.4.0
29
+ etils==0.6.0
30
+ exceptiongroup==1.0.0rc8
31
+ filelock==3.4.2
32
+ flatbuffers==2.0
33
+ flax==0.5.3
34
+ fonttools==4.28.5
35
+ frozenlist==1.2.0
36
+ fsspec==2021.11.1
37
+ future==0.18.2
38
+ gast==0.4.0
39
+ gitdb==4.0.9
40
+ gitpython==3.1.26
41
+ google-api-core==1.31.5
42
+ google-api-python-client==1.8.0
43
+ google-auth-httplib2==0.1.0
44
+ google-auth-oauthlib==0.4.6
45
+ google-auth==2.3.3
46
+ google-pasta==0.2.0
47
+ googleapis-common-protos==1.54.0
48
+ grpcio==1.43.0
49
+ h5py==3.6.0
50
+ httplib2==0.20.2
51
+ huggingface-hub==0.2.1
52
+ hypothesis==6.53.0
53
+ idna==3.3
54
+ importlib-metadata==4.10.0
55
+ importlib-resources==5.4.0
56
+ ipython==7.31.0
57
+ jax==0.3.15
58
+ jaxlib==0.3.15
59
+ jedi==0.18.1
60
+ jiwer==2.3.0
61
+ joblib==1.1.0
62
+ keras-preprocessing==1.1.2
63
+ keras==2.7.0
64
+ kiwisolver==1.3.2
65
+ libclang==12.0.0
66
+ librosa==0.9.2
67
+ libtpu-nightly==0.1.dev20220722
68
+ llvmlite==0.39.0
69
+ markdown==3.3.6
70
+ matplotlib-inline==0.1.3
71
+ matplotlib==3.5.1
72
+ ml-collections==0.1.0
73
+ msgpack==1.0.3
74
+ multidict==5.2.0
75
+ multiprocess==0.70.12.2
76
+ numba==0.56.0
77
+ numpy==1.22.0
78
+ oauth2client==4.1.3
79
+ oauthlib==3.1.1
80
+ opt-einsum==3.3.0
81
+ optax==0.1.3
82
+ packaging==21.3
83
+ pandas==1.3.5
84
+ parso==0.8.3
85
+ pathtools==0.1.2
86
+ pexpect==4.8.0
87
+ pickleshare==0.7.5
88
+ pillow==9.0.0
89
+ pip==22.2.1
90
+ pkg-resources==0.0.0
91
+ pooch==1.6.0
92
+ promise==2.3
93
+ prompt-toolkit==3.0.24
94
+ protobuf==3.19.1
95
+ psutil==5.9.0
96
+ ptyprocess==0.7.0
97
+ pyarrow==6.0.1
98
+ pyasn1-modules==0.2.8
99
+ pyasn1==0.4.8
100
+ pycparser==2.21
101
+ pyctcdecode==0.4.0
102
+ pygments==2.11.1
103
+ pygtrie==2.5.0
104
+ pyparsing==3.0.6
105
+ python-dateutil==2.8.2
106
+ python-levenshtein==0.12.2
107
+ pytz==2021.3
108
+ pyyaml==6.0
109
+ regex==2021.11.10
110
+ requests-oauthlib==1.3.0
111
+ requests==2.27.0
112
+ resampy==0.3.1
113
+ responses==0.18.0
114
+ rich==11.2.0
115
+ rsa==4.8
116
+ sacremoses==0.0.46
117
+ scikit-learn==1.1.1
118
+ scipy==1.7.3
119
+ sentry-sdk==1.5.2
120
+ setuptools==44.0.0
121
+ shortuuid==1.0.8
122
+ six==1.16.0
123
+ smmap==5.0.0
124
+ sortedcontainers==2.4.0
125
+ soundfile==0.10.3.post1
126
+ sox==1.4.1
127
+ subprocess32==3.5.4
128
+ tensorboard-data-server==0.6.1
129
+ tensorboard-plugin-wit==1.8.0
130
+ tensorboard==2.7.0
131
+ tensorflow-cpu==2.7.0
132
+ tensorflow-datasets==4.4.0
133
+ tensorflow-estimator==2.7.0
134
+ tensorflow-io-gcs-filesystem==0.23.1
135
+ tensorflow-metadata==1.5.0
136
+ tensorflow==2.7.0
137
+ tensorstore==0.1.21
138
+ termcolor==1.1.0
139
+ threadpoolctl==3.1.0
140
+ tokenizers==0.11.2
141
+ toolz==0.11.2
142
+ torch==1.12.0
143
+ torchaudio==0.12.0+cpu
144
+ tqdm==4.62.3
145
+ traitlets==5.1.1
146
+ transformers==4.21.0
147
+ typing-extensions==4.3.0
148
+ uritemplate==3.0.1
149
+ urllib3==1.26.7
150
+ wandb==0.12.9
151
+ wcwidth==0.2.5
152
+ werkzeug==2.0.2
153
+ wheel==0.37.1
154
+ wrapt==1.13.3
155
+ xxhash==2.0.2
156
+ yarl==1.7.2
157
+ yaspin==2.1.0
158
+ zipp==3.7.0
wandb/run-20220805_152536-2fzkf8n5/files/wandb-metadata.json ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2022-08-05T15:25:40.370466",
5
+ "startedAt": "2022-08-05T15:25:36.782402",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--model_name_or_path=facebook/wav2vec2-xls-r-1b",
11
+ "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu",
12
+ "--tokenizer_name=./",
13
+ "--output_dir=./",
14
+ "--overwrite_output_dir",
15
+ "--num_train_epochs=40",
16
+ "--per_device_train_batch_size=2",
17
+ "--per_device_eval_batch_size=2",
18
+ "--gradient_accumulation_steps=2",
19
+ "--precision=full_mixed",
20
+ "--matmul_precision=bfloat16",
21
+ "--multisteps",
22
+ "--learning_rate=2e-4",
23
+ "--warmup_steps=2000",
24
+ "--length_column_name=input_length",
25
+ "--evaluation_strategy=steps",
26
+ "--text_column_name=text",
27
+ "--save_steps=5000",
28
+ "--eval_steps=5000",
29
+ "--logging_steps=100",
30
+ "--layerdrop=0.041",
31
+ "--attention_dropout=0.094",
32
+ "--activation_dropout=0.055",
33
+ "--hidden_dropout=0.047",
34
+ "--save_total_limit=5",
35
+ "--freeze_feature_encoder",
36
+ "--feat_proj_dropout=0.04",
37
+ "--mask_time_prob=0.082",
38
+ "--mask_time_length=10",
39
+ "--mask_feature_prob=0.25",
40
+ "--mask_feature_length=64",
41
+ "--gradient_checkpointing",
42
+ "--min_duration_in_seconds=0.5",
43
+ "--max_duration_in_seconds=30.0",
44
+ "--use_auth_token",
45
+ "--seed=42",
46
+ "--group_by_length",
47
+ "--do_train",
48
+ "--do_eval",
49
+ "--push_to_hub",
50
+ "--preprocessing_num_workers=32",
51
+ "--ctc_zero_infinity",
52
+ "--do_lower_case",
53
+ "--wandb_project=wav2vec2",
54
+ "--wandb_name=wav2vec2-1b-npsc-nst-tpu",
55
+ "--remove_punctuation"
56
+ ],
57
+ "state": "running",
58
+ "program": "run_flax_speech_recognition_ctc.py",
59
+ "codePath": "run_flax_speech_recognition_ctc.py",
60
+ "git": {
61
+ "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu",
62
+ "commit": "e2b1320cc68c3ce129a1d654965e0d3eb44e0558"
63
+ },
64
+ "email": "versae@gmail.com",
65
+ "root": "/data/wav2vec2-1b-npsc-nst",
66
+ "host": "t1v-n-eedfb410-w-0",
67
+ "username": "javierr",
68
+ "executable": "/data/flax/bin/python"
69
+ }
wandb/run-20220805_152536-2fzkf8n5/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"_wandb": {"runtime": 1183}}
wandb/run-20220805_152536-2fzkf8n5/logs/debug-internal.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c137bfb0e98ec07c29d675736207e73a2a80e52db6b05737195bcc03723ca10a
3
+ size 51309
wandb/run-20220805_152536-2fzkf8n5/logs/debug.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38b16fcef53f7c67610a4fe7d00663568b5a7c2f40e719c37aa96391ea55eb3e
3
+ size 6063
wandb/run-20220805_152536-2fzkf8n5/run-2fzkf8n5.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a29f323346e50c6ded2df9a6993e83112cafa162f35900a0bbd05f05dddc5e29
3
+ size 250453
wandb/run-20220805_230151-2y71vcu4/files/code/run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import unreplicate, pad_shard_unpad
44
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.utils import check_min_version
59
+ from transformers.utils.versions import require_version
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ check_min_version("4.17.0.dev0")
64
+
65
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ @flax.struct.dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ feature_extractor_name: Optional[str] = field(
86
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
91
+ )
92
+ use_fast_tokenizer: bool = field(
93
+ default=True,
94
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
95
+ )
96
+ model_revision: str = field(
97
+ default="main",
98
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
99
+ )
100
+ use_auth_token: bool = field(
101
+ default=False,
102
+ metadata={
103
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
104
+ "with private models)."
105
+ },
106
+ )
107
+ freeze_feature_encoder: bool = field(
108
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
109
+ )
110
+ attention_dropout: float = field(
111
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
112
+ )
113
+ activation_dropout: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
117
+ },
118
+ )
119
+ hidden_dropout: float = field(
120
+ default=0.1,
121
+ metadata={
122
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
123
+ },
124
+ )
125
+ feat_proj_dropout: float = field(
126
+ default=0.0,
127
+ metadata={
128
+ "help": "The feat proj dropout probability for feature encoder representations."
129
+ },
130
+ )
131
+ final_dropout: float = field(
132
+ default=0.0,
133
+ metadata={"help": "The dropout probability for the final projection layer."},
134
+ )
135
+ mask_time_prob: float = field(
136
+ default=0.1,
137
+ metadata={
138
+ "help": "The spec aug dropout probability for feature encoder representations."
139
+ },
140
+ )
141
+ mask_time_length: int = field(
142
+ default=10,
143
+ metadata={"help": "Length of vector span to mask along the time axis."},
144
+ )
145
+ mask_feature_prob: float = field(
146
+ default=0.0,
147
+ metadata={
148
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
149
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
150
+ },
151
+ )
152
+ mask_feature_length: int = field(
153
+ default=10,
154
+ metadata={"help": "Length of vector span to mask along the feature axis."},
155
+ )
156
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
157
+ ctc_loss_reduction: Optional[str] = field(
158
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
159
+ )
160
+ ctc_zero_infinity: Optional[bool] = field(
161
+ default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."}
162
+ )
163
+
164
+
165
+ @flax.struct.dataclass
166
+ class DataTrainingArguments:
167
+ """
168
+ Arguments pertaining to what data we are going to input our model for training and eval.
169
+ """
170
+
171
+ dataset_name: str = field(
172
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
173
+ )
174
+ dataset_config_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
176
+ )
177
+ text_column: Optional[str] = field(
178
+ default=None,
179
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
180
+ )
181
+ dataset_cache_dir: Optional[str] = field(
182
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
183
+ )
184
+ overwrite_cache: bool = field(
185
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
186
+ )
187
+ preprocessing_num_workers: Optional[int] = field(
188
+ default=None,
189
+ metadata={"help": "The number of processes to use for the preprocessing."},
190
+ )
191
+ max_train_samples: Optional[int] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ },
197
+ )
198
+ max_eval_samples: Optional[int] = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
202
+ "value if set."
203
+ },
204
+ )
205
+ max_test_samples: Optional[int] = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
209
+ "value if set."
210
+ },
211
+ )
212
+ audio_column_name: str = field(
213
+ default="audio",
214
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
215
+ )
216
+ text_column_name: str = field(
217
+ default="text",
218
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
219
+ )
220
+ max_duration_in_seconds: float = field(
221
+ default=20.0,
222
+ metadata={
223
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
224
+ },
225
+ )
226
+ min_duration_in_seconds: float = field(
227
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
228
+ )
229
+ max_label_length: Optional[int] = field(
230
+ default=512,
231
+ metadata={
232
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
233
+ "than this will be filtered."
234
+ },
235
+ )
236
+ min_label_length: Optional[int] = field(
237
+ default=2,
238
+ metadata={
239
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
240
+ "than this will be filtered."
241
+ },
242
+ )
243
+ pad_input_to_multiple_of: Optional[int] = field(
244
+ default=32000,
245
+ metadata={
246
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
247
+ "This is important to avoid triggering recompilations on TPU."
248
+ },
249
+ )
250
+ pad_target_to_multiple_of: Optional[int] = field(
251
+ default=None,
252
+ metadata={
253
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
254
+ "This is important to avoid triggering recompilations on TPU."
255
+ },
256
+ )
257
+ preprocessing_only: bool = field(
258
+ default=False,
259
+ metadata={
260
+ "help": "Whether to only do data preprocessing and skip training. "
261
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
262
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
263
+ "so that the cached datasets can consequently be loaded in distributed training"
264
+ },
265
+ )
266
+ train_split_name: str = field(
267
+ default="train",
268
+ metadata={
269
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
270
+ },
271
+ )
272
+ eval_split_name: str = field(
273
+ default="validation",
274
+ metadata={
275
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
276
+ },
277
+ )
278
+ do_lower_case: bool = field(
279
+ default=True,
280
+ metadata={"help": "Whether the target text should be lower cased."},
281
+ )
282
+ wandb_project: str = field(
283
+ default="flax-speech-recognition-ctc",
284
+ metadata={"help": "The name of the wandb project."},
285
+ )
286
+ wandb_name: str = field(
287
+ default=None,
288
+ metadata={"help": "The name of the wandb run."},
289
+ )
290
+ wandb_job_type: str = field(
291
+ default="CTC",
292
+ metadata={"help": "The name of the wandb job type."},
293
+ )
294
+ test_split_name: str = field(
295
+ default="test",
296
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
297
+ )
298
+ remove_punctuation: bool = field(
299
+ default=False, metadata={"help": "Whether or not to remove punctuation during training."}
300
+ )
301
+ skip_steps: Optional[int] = field(
302
+ default=0,
303
+ metadata={
304
+ "help": "Skip this number of steps. Useful to continue training"
305
+ },
306
+ )
307
+
308
+
309
+ # @flax.struct.dataclass
310
+ @dataclass
311
+ class FlaxTrainingArguments(TrainingArguments):
312
+ precision: str = field(
313
+ default="full",
314
+ metadata={
315
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
316
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
317
+ },
318
+ )
319
+ matmul_precision: str = field(
320
+ default="default",
321
+ metadata={
322
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
323
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
324
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
325
+ "it only changes the behaviors of calls with no such argument provided. "
326
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
327
+ },
328
+ )
329
+ multisteps: bool = field(
330
+ default=False,
331
+ metadata={
332
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
333
+ "a custom gradient accumulation implementation will be employed."
334
+ },
335
+ )
336
+
337
+
338
+ def to_fp32(t):
339
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
340
+
341
+
342
+ def to_bf16(t):
343
+ return jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
344
+
345
+
346
+ class MixedPrecisionTrainState(struct.PyTreeNode):
347
+ """Train state for use with a single Optax optimizer.
348
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
349
+
350
+ Synopsis::
351
+
352
+ state = TrainState.create(
353
+ apply_fn=model.apply,
354
+ params=variables['params'],
355
+ tx=tx)
356
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
357
+ for batch in data:
358
+ grads = grad_fn(state.params, batch)
359
+ state = state.apply_gradients(grads=grads)
360
+
361
+ Args:
362
+ step: Counter starts at 0 and is incremented by every call to
363
+ `.apply_gradients()`.
364
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
365
+ convenience to have a shorter params list for the `train_step()` function
366
+ in your training loop.
367
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
368
+ tx: An Optax gradient transformation.
369
+ opt_state: The state for `tx`.
370
+ dropout_rng: PRNG key for stochastic operations.
371
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
372
+ """
373
+
374
+ step: int
375
+ apply_fn: Callable = struct.field(pytree_node=False)
376
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
377
+ params: core.FrozenDict[str, Any]
378
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
379
+ opt_state: optax.OptState
380
+ dropout_rng: jnp.ndarray
381
+ max_grad_norm: Optional[float] = 1.0
382
+
383
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
384
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
385
+
386
+ Note that internally this function calls `.tx.update()` followed by a call
387
+ to `optax.apply_updates()` to update `params` and `opt_state`.
388
+
389
+ Args:
390
+ grads: Gradients that have the same pytree structure as `.params`.
391
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
392
+
393
+ Returns:
394
+ An updated instance of `self` with `step` incremented by one, `params`
395
+ and `opt_state` updated by applying `grads`, and additional attributes
396
+ replaced as specified by `kwargs`.
397
+ """
398
+
399
+ # clip gradients by global l2 norm
400
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
401
+ g_norm = linear_algebra.global_norm(grads)
402
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
403
+ grads = jax.tree_util.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
404
+
405
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
406
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
407
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
408
+
409
+ new_params = optax.apply_updates(self.params, updates)
410
+ return self.replace(
411
+ step=self.step + 1,
412
+ params=new_params,
413
+ opt_state=to_dtype(new_opt_state),
414
+ **kwargs,
415
+ )
416
+
417
+ @classmethod
418
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
419
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
420
+ # downcast optimizer state to bf16 if mixed-precision training
421
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
422
+ return cls(
423
+ step=0,
424
+ apply_fn=apply_fn,
425
+ params=params,
426
+ tx=tx,
427
+ opt_state=opt_state,
428
+ **kwargs,
429
+ )
430
+
431
+ def replicate(self):
432
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
433
+
434
+
435
+ @flax.struct.dataclass
436
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
437
+ """
438
+ Data collator that will dynamically pad the inputs received.
439
+ Args:
440
+ processor ([`Wav2Vec2Processor`])
441
+ The processor used for proccessing the data.
442
+ decoder_start_token_id (:obj: `int`)
443
+ The begin-of-sentence of the decoder.
444
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
445
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
446
+ among:
447
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
448
+ sequence if provided).
449
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
450
+ maximum acceptable input length for the model if that argument is not provided.
451
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
452
+ different lengths).
453
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
454
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
455
+ See above for details.
456
+ max_input_length (:obj:`float`, `optional`):
457
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
458
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
459
+ If set will pad the input sequence to a multiple of the provided value.
460
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
461
+ 7.5 (Volta).
462
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
463
+ If set will pad the target sequence to a multiple of the provided value.
464
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
465
+ 7.5 (Volta).
466
+ """
467
+
468
+ processor: Any
469
+ input_padding: Union[bool, str] = "longest"
470
+ label_padding: Union[bool, str] = "max_length"
471
+ pad_input_to_multiple_of: Optional[int] = None
472
+ pad_to_multiple_of_label: Optional[int] = None
473
+ max_input_length: Optional[float] = None
474
+ max_label_length: Optional[float] = None
475
+
476
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
477
+ # split inputs and labels since they have to be of different lengths and need
478
+ # different padding methods
479
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
480
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
481
+
482
+ # reformat list to dict and set to pytorch format
483
+ batch = self.processor.feature_extractor.pad(
484
+ input_features,
485
+ max_length=self.max_input_length,
486
+ padding=self.input_padding,
487
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
488
+ return_tensors="np",
489
+ )
490
+
491
+ labels_batch = self.processor.tokenizer.pad(
492
+ label_features,
493
+ max_length=self.max_label_length,
494
+ padding=self.label_padding,
495
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
496
+ return_tensors="np",
497
+ )
498
+
499
+ labels = labels_batch["input_ids"]
500
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
501
+ labels = labels.filled(fill_value=-100)
502
+
503
+ batch["labels"] = labels
504
+
505
+ return batch
506
+
507
+
508
+ def get_grouped_indices(
509
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
510
+ ) -> np.array:
511
+ """
512
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
513
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
514
+ lengths. To do this, the indices are:
515
+
516
+ - randomly permuted (if a JAX rng is specified)
517
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
518
+ - sorted by length in each mega-batch
519
+
520
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
521
+ maximum length placed first, so that an OOM happens sooner rather than later.
522
+ """
523
+ lengths = dataset["input_length"]
524
+
525
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
526
+ if mega_batch_mult is None:
527
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
528
+ # Just in case, for tiny datasets
529
+ if mega_batch_mult == 0:
530
+ mega_batch_mult = 1
531
+
532
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
533
+ num_samples = len(lengths)
534
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
535
+
536
+ megabatch_size = mega_batch_mult * batch_size
537
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
538
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
539
+
540
+ # The rest is to get the biggest batch first.
541
+ # Since each megabatch is sorted by descending length, the longest element is the first
542
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
543
+ max_idx = np.argmax(megabatch_maximums).item()
544
+ # Switch to put the longest batch in first position
545
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
546
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
547
+
548
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
549
+
550
+ return megabatches
551
+
552
+
553
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
554
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
555
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
556
+ num_samples = len(samples_idx)
557
+ if drop_last:
558
+ samples_to_remove = num_samples % batch_size
559
+ if samples_to_remove != 0:
560
+ samples_idx = samples_idx[:-samples_to_remove]
561
+ sections_split = num_samples // batch_size
562
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
563
+ else:
564
+ sections_split = math.ceil(num_samples / batch_size)
565
+ samples_idx = np.array_split(samples_idx, sections_split)
566
+ return samples_idx
567
+
568
+
569
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
570
+ summary_writer.scalar("train_time", train_time, step)
571
+
572
+ train_metrics = get_metrics(train_metrics)
573
+ for key, vals in train_metrics.items():
574
+ tag = f"train_{key}"
575
+ for i, val in enumerate(vals):
576
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
577
+
578
+
579
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
580
+ for metric_name, value in eval_metrics.items():
581
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
582
+
583
+ if pred_str is not None:
584
+ # write output actual predictions for debugging
585
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
586
+
587
+
588
+ def write_wandb_log(metrics, step, prefix=None):
589
+ if jax.process_index() == 0:
590
+ log_metrics = {}
591
+ for k, v in metrics.items():
592
+ if "layer" in k:
593
+ log_metrics[f"{k}/"] = v
594
+ elif prefix is not None:
595
+ log_metrics[f"{prefix}/{k}"] = v
596
+ else:
597
+ log_metrics[k] = v
598
+ wandb.log(log_metrics, step)
599
+
600
+
601
+ def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"):
602
+ if jax.process_index() == 0:
603
+ # convert str data to a wandb compatible format
604
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
605
+ # we'll log the first 50 predictions for each epoch
606
+ wandb.log(
607
+ {
608
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
609
+ columns=["label_str", "pred_str"], data=str_data[:num_log]
610
+ )
611
+ },
612
+ step,
613
+ )
614
+
615
+
616
+ def create_learning_rate_fn(
617
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
618
+ ) -> Callable[[int], jnp.array]:
619
+ """Returns a linear warmup, linear_decay learning rate function."""
620
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
621
+ decay_fn = optax.linear_schedule(
622
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
623
+ )
624
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
625
+ return schedule_fn
626
+
627
+
628
+ def ctc_loss(
629
+ logits,
630
+ logits_attention_mask,
631
+ labels,
632
+ blank_id,
633
+ loss_reduction="mean",
634
+ output_emission_dict=False,
635
+ log_epsilon=-100000.0,
636
+ ):
637
+ """Computes CTC loss.
638
+ This function performs forward computation over an FSA with `N * 2` states
639
+ where `N` is the max number of labels. The states are split into two groups:
640
+ Phi states and emission states. a phi-state accepts repetition of
641
+ phi (blank)-symbols and transits to emission state when the correct label is
642
+ observed. An emission state accepts repetition of the label and transits to
643
+ the next phi states at any time (so called epsilon-transition).
644
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
645
+ and `N` denotes the time steps in `labels`.
646
+ Args:
647
+ logits: (B, T, K)-array containing log-probabilities of each class.
648
+ logitpaddings: (B, T)-array. Padding indicators for `logits`.
649
+ labels: (B, N)-array containing reference integer labels.
650
+ labelpaddings: (B, N)-array. Padding indicators for `labels`. Currently,
651
+ `labels` must be right-padded, i.e. each row of `labelpaddings` must be
652
+ repetition of zeroes, followed by repetition of ones.
653
+ blank_id: Id for blank token.
654
+ loss_reduction: one of "mean", "sum", "default"
655
+ - "none": no reduction is applied.
656
+ - "mean": output loss will be divided by target lengths and then the
657
+ mean over the batch is taken.
658
+ - "sum": output loss are summed over batch
659
+ output_emission_dict: whether to output additional information about the emission probs
660
+ Returns:
661
+ A pair of `(per_seq_loss, aux)`.
662
+ per_seq_loss:
663
+ (B,)-array containing loss values for each sequence in the batch.
664
+ aux: Dictionary containing interim variables used for computing losses.
665
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
666
+ phi-state corresponding to the n-th label.
667
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
668
+ emission-state corresponding to the n-th label.
669
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
670
+ corresponding to each time frame.
671
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
672
+ corresponding to each time frame.
673
+ """
674
+ # label paddings are indicated by -100
675
+ labelpaddings = labels < 0
676
+ # logit paddings are the inverse of attention_mask
677
+ logitpaddings = ~logits_attention_mask
678
+
679
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
680
+ batchsize, unused_maxinputlen, num_classes = logits.shape
681
+ batchsize_, maxlabellen = labels.shape
682
+
683
+ logprobs = jax.nn.log_softmax(logits)
684
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
685
+
686
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
687
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
688
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
689
+
690
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
691
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
692
+
693
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
694
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
695
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
696
+
697
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
698
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
699
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
700
+
701
+ def loop_body(prev, x):
702
+ prev_phi, prev_emit = prev
703
+ # emit-to-phi epsilon transition, except if the next label is repetition
704
+ prev_phi_orig = prev_phi
705
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
706
+
707
+ logprob_emit, logprob_phi, pad = x
708
+
709
+ # phi-to-emit transition
710
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
711
+ # self-loop transition
712
+ next_phi = prev_phi + logprob_phi
713
+ # emit-to-phi blank transition only when the next label is repetition
714
+ next_phi = next_phi.at[:, 1:].set(
715
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
716
+ )
717
+
718
+ pad = pad.reshape((batchsize, 1))
719
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
720
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
721
+
722
+ return (next_phi, next_emit), (next_phi, next_emit)
723
+
724
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
725
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
726
+
727
+ # last row needs to be updated with the last epsilon transition
728
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
729
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
730
+
731
+ # extract per_seq_loss
732
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
733
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
734
+
735
+ if loss_reduction == "mean":
736
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
737
+ loss = (per_seq_loss / target_lengths).mean()
738
+ elif loss_reduction == "sum":
739
+ loss = per_seq_loss.sum()
740
+ else:
741
+ loss = per_seq_loss
742
+
743
+ if not output_emission_dict:
744
+ return loss
745
+
746
+ return loss, {
747
+ "logalpha_phi": logalpha_phi,
748
+ "logalpha_emit": logalpha_emit,
749
+ "logprobs_phi": logprobs_phi,
750
+ "logprobs_emit": logprobs_emit,
751
+ }
752
+
753
+
754
+ def make_dataset(data_args, seed=42):
755
+ # Pre-processing dataset
756
+ import re
757
+
758
+ def map_nst(entry):
759
+ text = entry["text"].lower()
760
+ text = text.replace("(...vær stille under dette opptaket...)", "")
761
+ text = re.sub('[áàâ]', 'a', text)
762
+ text = re.sub('[ä]', 'æ', text)
763
+ text = re.sub('[éèëê]', 'e', text)
764
+ text = re.sub('[íìïî]', 'i', text)
765
+ text = re.sub('[óòöô]', 'o', text)
766
+ text = re.sub('[ö]', 'ø', text)
767
+ text = re.sub('[ç]', 'c', text)
768
+ text = re.sub('[úùüû]', 'u', text)
769
+ # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
770
+ text = re.sub('\s+', ' ', text)
771
+ return {"text": text}
772
+
773
+ def filter_nst(entry):
774
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
775
+ return False # Too short
776
+ if re.match(entry["type"], "pIW|CA"):
777
+ return False # Spelling out words
778
+ return True
779
+
780
+ def filter_npsc(entry):
781
+ # False if there are digits in the text
782
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
783
+ return False # Too short
784
+ if re.search("\d", entry["text"]):
785
+ return False
786
+ return True
787
+
788
+ def map_npsc(entry):
789
+ batch = {"text": entry["text"].lower()}
790
+ batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
791
+ batch["text"] = re.sub('[ä]', 'æ', batch["text"])
792
+ batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
793
+ batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
794
+ batch["text"] = re.sub('[óòöô]', 'o', batch["text"])
795
+ batch["text"] = re.sub('[ö]', 'ø', batch["text"])
796
+ batch["text"] = re.sub('[ç]', 'c', batch["text"])
797
+ batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
798
+ batch["text"] = re.sub('\s', ' ', batch["text"])
799
+ batch["text"] = re.sub('<ee>', 'eee', batch["text"])
800
+ batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
801
+ batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
802
+ batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
803
+ # batch["text"] = re.sub('<inaudible>', '?', batch["text"])
804
+ if "<" in batch["text"]:
805
+ raise ValueError(batch["text"])
806
+ return batch
807
+
808
+ nst = datasets.load_dataset("NbAiLab/NST", "no-close")
809
+ npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
810
+ # TODO NST_hesitate
811
+
812
+ split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC
813
+ nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed)
814
+ nst[data_args.train_split_name] = nst_train["train"]
815
+ nst[data_args.eval_split_name] = nst_train["test"]
816
+
817
+ nst = nst.filter(filter_nst).map(
818
+ map_nst,
819
+ num_proc=data_args.preprocessing_num_workers,
820
+ desc="filtering NST",
821
+ ).shuffle(seed=seed)
822
+ npsc = npsc.filter(filter_npsc).map(
823
+ map_npsc,
824
+ num_proc=data_args.preprocessing_num_workers,
825
+ desc="filtering NPSC",
826
+ ).shuffle(seed=seed)
827
+
828
+ npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]])
829
+ nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]])
830
+
831
+ combined = {}
832
+ for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name:
833
+ probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples
834
+ probs = (probs / probs.sum()).tolist()
835
+ comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
836
+ combined[split] = comb
837
+
838
+ return datasets.DatasetDict(**combined)
839
+
840
+ def main():
841
+ # 1. Parse input arguments
842
+ # See all possible arguments in src/transformers/training_args.py
843
+ # or by passing the --help flag to this script.
844
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
845
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
846
+
847
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
848
+ # If we pass only one argument to the script and it's the path to a json file,
849
+ # let's parse it to get our arguments.
850
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
851
+ else:
852
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
853
+
854
+ # 2. Setup logging
855
+ # Make one log on every process with the configuration for debugging.
856
+ logging.basicConfig(
857
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
858
+ datefmt="%m/%d/%Y %H:%M:%S",
859
+ handlers=[logging.StreamHandler(sys.stdout)],
860
+ )
861
+ # Set the verbosity to info of the Transformers logger.
862
+ # We only want one process per machine to log things on the screen.
863
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
864
+ if jax.process_index() == 0:
865
+ datasets.utils.logging.set_verbosity_warning()
866
+ transformers.utils.logging.set_verbosity_info()
867
+ else:
868
+ datasets.utils.logging.set_verbosity_error()
869
+ transformers.utils.logging.set_verbosity_error()
870
+
871
+ # Set up wandb run
872
+ if jax.process_index() == 0:
873
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
874
+
875
+ logger.info("Training/evaluation parameters %s", training_args)
876
+
877
+ # Set the default TPU matmul precision and display the number of devices
878
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
879
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
880
+
881
+ # 4. Load dataset
882
+
883
+ set_seed(training_args.seed)
884
+ raw_datasets = make_dataset(data_args, seed=training_args.seed)
885
+
886
+ # raw_datasets = DatasetDict()
887
+
888
+ # if training_args.do_train:
889
+ # raw_datasets[data_args.train_split_name] = load_dataset(
890
+ # data_args.dataset_name,
891
+ # data_args.dataset_config_name,
892
+ # split=data_args.train_split_name,
893
+ # cache_dir=data_args.dataset_cache_dir,
894
+ # use_auth_token=True if model_args.use_auth_token else None,
895
+ # )
896
+
897
+ # if training_args.do_eval:
898
+ # raw_datasets[data_args.eval_split_name] = load_dataset(
899
+ # data_args.dataset_name,
900
+ # data_args.dataset_config_name,
901
+ # split=data_args.eval_split_name,
902
+ # cache_dir=data_args.dataset_cache_dir,
903
+ # use_auth_token=True if model_args.use_auth_token else None,
904
+ # )
905
+
906
+ # if training_args.do_predict:
907
+ # test_split = data_args.test_split_name.split("+")
908
+ # for split in test_split:
909
+ # raw_datasets[split] = load_dataset(
910
+ # data_args.dataset_name,
911
+ # data_args.dataset_config_name,
912
+ # split=split,
913
+ # cache_dir=data_args.dataset_cache_dir,
914
+ # use_auth_token=True if model_args.use_auth_token else None,
915
+ # )
916
+
917
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
918
+ raise ValueError(
919
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
920
+ "training, evaluation or prediction has to be done."
921
+ )
922
+
923
+ # if not training, there is no need to run multiple epochs
924
+ if not training_args.do_train:
925
+ training_args.num_train_epochs = 1
926
+
927
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
928
+ raise ValueError(
929
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
930
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
931
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
932
+ )
933
+
934
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
935
+ raise ValueError(
936
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
937
+ "Make sure to set `--text_column_name` to the correct text column - one of "
938
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
939
+ )
940
+
941
+ # 5. Load pretrained model, tokenizer, and feature extractor
942
+ #
943
+ # Distributed training:
944
+ # The .from_pretrained methods guarantee that only one local process can concurrently
945
+ config = Wav2Vec2Config.from_pretrained(
946
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
947
+ cache_dir=model_args.cache_dir,
948
+ revision=model_args.model_revision,
949
+ use_auth_token=True if model_args.use_auth_token else None,
950
+ )
951
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
952
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
953
+ cache_dir=model_args.cache_dir,
954
+ revision=model_args.model_revision,
955
+ use_auth_token=True if model_args.use_auth_token else None,
956
+ )
957
+ tokenizer = AutoTokenizer.from_pretrained(
958
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
959
+ cache_dir=model_args.cache_dir,
960
+ revision=model_args.model_revision,
961
+ use_auth_token=True if model_args.use_auth_token else None,
962
+ )
963
+ # update config according to training args, model args, and tokenizer attributes
964
+ config.update(
965
+ {
966
+ "feat_proj_dropout": model_args.feat_proj_dropout,
967
+ "attention_dropout": model_args.attention_dropout,
968
+ "hidden_dropout": model_args.hidden_dropout,
969
+ "final_dropout": model_args.final_dropout,
970
+ "mask_time_prob": model_args.mask_time_prob,
971
+ "mask_time_length": model_args.mask_time_length,
972
+ "mask_feature_prob": model_args.mask_feature_prob,
973
+ "mask_feature_length": model_args.mask_feature_length,
974
+ "gradient_checkpointing": training_args.gradient_checkpointing,
975
+ "layerdrop": model_args.layerdrop,
976
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
977
+ "ctc_zero_infinity": model_args.ctc_zero_infinity,
978
+ "pad_token_id": tokenizer.pad_token_id,
979
+ "vocab_size": tokenizer.vocab_size, # len(tokenizer),
980
+ "activation_dropout": model_args.activation_dropout,
981
+ }
982
+ )
983
+
984
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
985
+ raise ValueError(
986
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
987
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus,"
988
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely "
989
+ "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`."
990
+ )
991
+
992
+ if training_args.precision == "full_mixed":
993
+ dtype = jnp.bfloat16
994
+ training_args.mixed_precision = True
995
+ elif training_args.precision == "half_mixed":
996
+ dtype = jnp.bfloat16
997
+ training_args.mixed_precision = False
998
+ else:
999
+ dtype = jnp.float32
1000
+ training_args.mixed_precision = False
1001
+
1002
+ try:
1003
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1004
+ model_args.model_name_or_path,
1005
+ config=config,
1006
+ dtype=dtype,
1007
+ cache_dir=model_args.cache_dir,
1008
+ revision=model_args.model_revision,
1009
+ use_auth_token=True if model_args.use_auth_token else None,
1010
+ )
1011
+ except:
1012
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
1013
+ model_args.model_name_or_path,
1014
+ config=config,
1015
+ dtype=dtype,
1016
+ cache_dir=model_args.cache_dir,
1017
+ revision=model_args.model_revision,
1018
+ use_auth_token=True if model_args.use_auth_token else None,
1019
+ from_pt=True,
1020
+ )
1021
+
1022
+ # 6. Resample speech dataset ALWAYS
1023
+ raw_datasets = raw_datasets.cast_column(
1024
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
1025
+ )
1026
+
1027
+ # 7. Preprocessing the datasets.
1028
+ # We need to read the audio files as arrays and tokenize the targets.
1029
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
1030
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
1031
+ max_target_length = data_args.max_label_length
1032
+ min_target_length = data_args.min_label_length
1033
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
1034
+ audio_column_name = data_args.audio_column_name
1035
+ num_workers = data_args.preprocessing_num_workers
1036
+ text_column_name = data_args.text_column_name
1037
+ model_input_name = feature_extractor.model_input_names[0]
1038
+ do_lower_case = data_args.do_lower_case
1039
+ dataset_name = data_args.dataset_name
1040
+ chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
1041
+ chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]'
1042
+ # gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
1043
+ # gigaspeech_disfluencies = ["<other>", "<sil>"]
1044
+ # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
1045
+ # "[vocalized-noise]", "_1"]
1046
+ # swb_punctuations = ["{", "}", "[", "]-", "]"]
1047
+ # earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>", "<unk>"]
1048
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
1049
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
1050
+
1051
+ if training_args.do_train and data_args.max_train_samples is not None:
1052
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples))
1053
+
1054
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1055
+ raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples))
1056
+
1057
+ if training_args.do_predict and data_args.max_test_samples is not None:
1058
+ raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_eval_samples))
1059
+
1060
+ if training_args.do_train and data_args.remove_punctuation:
1061
+
1062
+ def remove_punctuation(batch):
1063
+ batch[text_column_name] = (
1064
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
1065
+ )
1066
+
1067
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map(
1068
+ remove_punctuation,
1069
+ num_proc=data_args.preprocessing_num_workers,
1070
+ desc="removing punctuation from train split",
1071
+ )
1072
+
1073
+ # filter data where the targets are ignored in scoring
1074
+ def is_target_labels(input_str):
1075
+ return input_str.lower() not in ignore_segments
1076
+
1077
+ raw_datasets = raw_datasets.filter(
1078
+ is_target_labels,
1079
+ num_proc=num_workers,
1080
+ input_columns=[text_column_name],
1081
+ desc="filtering data where the targets are ignored in scoring",
1082
+ )
1083
+
1084
+ def prepare_dataset(batch):
1085
+ # process audio
1086
+ try:
1087
+ sample = batch[audio_column_name]
1088
+ except ValueError:
1089
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
1090
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1091
+ # process audio length
1092
+ batch[model_input_name] = inputs.input_values[0]
1093
+ batch["input_length"] = len(batch["input_values"])
1094
+
1095
+ # process targets
1096
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
1097
+
1098
+ # if dataset_name == "google/xtreme_s":
1099
+ # # Finally, we tokenize the processed text
1100
+ # batch["labels"] = tokenizer(input_str).input_ids
1101
+ # batch["labels_length"] = len(batch["labels"])
1102
+ # return batch
1103
+
1104
+ # # Common Voice 9
1105
+ # if input_str.startswith('"') and input_str.endswith('"'):
1106
+ # # we can remove trailing quotation marks as they do not affect the transcription
1107
+ # input_str = input_str[1:-1]
1108
+ # # normalize quotation marks
1109
+ # input_str = re.sub(r'["“”]', '"', input_str)
1110
+ # # normalize apostrophes
1111
+ # input_str = re.sub(r"[’']", "'", input_str)
1112
+ # # normalize hyphens
1113
+ # input_str = re.sub(r"[—–]", "-", input_str)
1114
+ # # replace double quotation marks with single
1115
+ # input_str = input_str.replace('""', '"')
1116
+ # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str):
1117
+ # # for CV9, we'll normalize the text to always finish with punctuation
1118
+ # if input_str[-1] not in [".", "?", "!"]:
1119
+ # input_str = input_str + "."
1120
+
1121
+ # # TEDLIUM-3
1122
+ # # delete the <unk> token from the text and replace spaced apostrophes with un-spaced
1123
+ # input_str = input_str.replace("<unk>", "").replace(" '", "'")
1124
+
1125
+ # # GigaSpeech
1126
+ # for disfluency in gigaspeech_disfluencies:
1127
+ # input_str = input_str.replace(disfluency, "")
1128
+ # # convert spelled out punctuation to symbolic form
1129
+ # for punctuation, replacement in gigaspeech_punctuation.items():
1130
+ # input_str = input_str.replace(punctuation, replacement)
1131
+ # if dataset_name == "speechcolab/gigaspeech" and len(input_str):
1132
+ # # for GS, we'll normalize the text to always finish with punctuation
1133
+ # if input_str[-1] not in [".", "?", "!"]:
1134
+ # input_str = input_str + "."
1135
+
1136
+ # # SWB
1137
+ # for disfluency in swb_disfluencies:
1138
+ # input_str = input_str.replace(disfluency, "")
1139
+ # # remove parenthesised text (test data only)
1140
+ # input_str = re.sub("[\(].*?[\)]", "", input_str)
1141
+ # for punctuation in swb_punctuations:
1142
+ # input_str = input_str.replace(punctuation, "")
1143
+ # # replace anomalous words with their correct transcriptions
1144
+ # split_str = input_str.split("/")
1145
+ # if len(split_str) > 1:
1146
+ # input_str = " ".join(
1147
+ # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1148
+
1149
+ # # Earnings 22
1150
+ # for disfluency in earnings_disfluencies:
1151
+ # input_str = input_str.replace(disfluency, "")
1152
+ # # replace mal-formatted ellipsis
1153
+ # input_str = input_str.replace("…", ".")
1154
+
1155
+ # JIWER compliance
1156
+ # remove multiple spaces
1157
+ input_str = re.sub(r"\s\s+", " ", input_str)
1158
+ # strip trailing spaces
1159
+ input_str = input_str.strip()
1160
+
1161
+ # Finally, we tokenize the processed text
1162
+ batch["labels"] = tokenizer(input_str).input_ids
1163
+ batch["labels_length"] = len(batch["labels"])
1164
+ return batch
1165
+
1166
+ vectorized_datasets = raw_datasets.map(
1167
+ prepare_dataset,
1168
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1169
+ num_proc=num_workers,
1170
+ desc="preprocess dataset",
1171
+ )
1172
+
1173
+ # filter data with inputs shorter than min_input_length or longer than max_input_length
1174
+ def is_audio_in_length_range(length):
1175
+ return length > min_input_length and length < max_input_length
1176
+
1177
+ vectorized_datasets = vectorized_datasets.filter(
1178
+ is_audio_in_length_range,
1179
+ num_proc=num_workers,
1180
+ input_columns=["input_length"],
1181
+ )
1182
+
1183
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1184
+ def is_labels_in_length_range(length):
1185
+ return length > min_target_length # and length < max_target_length
1186
+
1187
+ vectorized_datasets = vectorized_datasets.filter(
1188
+ is_labels_in_length_range,
1189
+ num_proc=num_workers,
1190
+ input_columns=["labels_length"],
1191
+ )
1192
+
1193
+ # for large datasets it is advised to run the preprocessing on a
1194
+ # single machine first with `args.preprocessing_only` since there will mostly likely
1195
+ # be a timeout when running the script in distributed mode.
1196
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1197
+ # cached dataset
1198
+ if data_args.preprocessing_only:
1199
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1200
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1201
+ return
1202
+
1203
+ # 8. Load Metrics
1204
+ wer_metric = load_metric("wer")
1205
+ cer_metric = load_metric("cer")
1206
+
1207
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1208
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1209
+
1210
+ pred_str = tokenizer.batch_decode(pred_ids)
1211
+ # we do not want to group tokens when computing the metrics
1212
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
1213
+
1214
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
1215
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
1216
+
1217
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1218
+
1219
+ # 9. save feature extractor, tokenizer and config
1220
+ feature_extractor.save_pretrained(training_args.output_dir)
1221
+ tokenizer.save_pretrained(training_args.output_dir)
1222
+ config.save_pretrained(training_args.output_dir)
1223
+
1224
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1225
+
1226
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1227
+ processor=processor,
1228
+ input_padding="longest",
1229
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1230
+ max_label_length=data_args.max_label_length,
1231
+ )
1232
+
1233
+ # Enable tensorboard only on the master node
1234
+ has_tensorboard = is_tensorboard_available()
1235
+ if has_tensorboard and jax.process_index() == 0:
1236
+ try:
1237
+ from flax.metrics.tensorboard import SummaryWriter
1238
+
1239
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1240
+ except ImportError as ie:
1241
+ has_tensorboard = False
1242
+ logger.warning(
1243
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1244
+ )
1245
+ else:
1246
+ logger.warning(
1247
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1248
+ "Please run `pip install tensorboard` to enable."
1249
+ )
1250
+
1251
+ # 10. Handle the repository creation
1252
+ if training_args.push_to_hub:
1253
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1254
+ git_lfs_extensions = f.read()
1255
+ if "*.wandb" not in git_lfs_extensions:
1256
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1257
+ if training_args.hub_model_id is None:
1258
+ repo_name = get_full_repo_name(
1259
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1260
+ )
1261
+ else:
1262
+ repo_name = training_args.hub_model_id
1263
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1264
+
1265
+ # 11. Initialize our training
1266
+ rng = jax.random.PRNGKey(training_args.seed)
1267
+ rng, dropout_rng = jax.random.split(rng)
1268
+
1269
+ # Store some constants
1270
+ max_steps = int(training_args.max_steps)
1271
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1272
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1273
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1274
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1275
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1276
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1277
+
1278
+ if training_args.do_train:
1279
+ num_train_samples = len(vectorized_datasets[data_args.train_split_name])
1280
+ steps_per_epoch = num_train_samples // batch_size_per_update
1281
+ if max_steps > 0:
1282
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1283
+ total_train_steps = max_steps
1284
+ else:
1285
+ num_epochs = int(training_args.num_train_epochs)
1286
+ total_train_steps = steps_per_epoch * num_epochs
1287
+
1288
+ # Create learning rate schedule
1289
+ # Create learning rate schedule
1290
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1291
+ total_train_steps,
1292
+ training_args.warmup_steps,
1293
+ training_args.learning_rate,
1294
+ )
1295
+
1296
+ # We use Optax's "masking" functionality to not apply weight decay
1297
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1298
+ # mask boolean with the same structure as the parameters.
1299
+ # The mask is True for parameters that should be decayed.
1300
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1301
+ # For FlaxT5, one should correct the layer norm parameter naming
1302
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1303
+ def decay_mask_fn(params):
1304
+ flat_params = traverse_util.flatten_dict(params)
1305
+ layer_norm_params = [
1306
+ (name, "scale")
1307
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1308
+ ]
1309
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1310
+ return traverse_util.unflatten_dict(flat_mask)
1311
+
1312
+ if training_args.adafactor:
1313
+ # Create Adafactor optimizer
1314
+ optim = optax.adafactor(
1315
+ learning_rate=linear_decay_lr_schedule_fn,
1316
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1317
+ weight_decay_rate=training_args.weight_decay,
1318
+ weight_decay_mask=decay_mask_fn,
1319
+ )
1320
+ else:
1321
+ # Create AdamW optimizer
1322
+ optim = optax.adamw(
1323
+ learning_rate=linear_decay_lr_schedule_fn,
1324
+ b1=training_args.adam_beta1,
1325
+ b2=training_args.adam_beta2,
1326
+ eps=training_args.adam_epsilon,
1327
+ weight_decay=training_args.weight_decay,
1328
+ mask=decay_mask_fn,
1329
+ )
1330
+
1331
+ # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
1332
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1333
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
1334
+ else:
1335
+ num_epochs = 0
1336
+ total_train_steps = 0
1337
+ num_train_samples = 0
1338
+ optim = None
1339
+
1340
+ # Setup train state
1341
+ state = MixedPrecisionTrainState.create(
1342
+ apply_fn=model.__call__,
1343
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1344
+ params=model.params,
1345
+ tx=optim,
1346
+ to_dtype=to_dtype,
1347
+ dropout_rng=dropout_rng,
1348
+ max_grad_norm=training_args.max_grad_norm,
1349
+ )
1350
+
1351
+ # Replicate the train state on each device
1352
+ state = state.replicate()
1353
+ blank_id = model.config.pad_token_id
1354
+
1355
+ # Define gradient update step fn
1356
+ def train_step(state, batch):
1357
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1358
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1359
+
1360
+ def compute_loss(params, minibatch):
1361
+ labels = minibatch.pop("labels")
1362
+ logits = state.apply_fn(
1363
+ **minibatch,
1364
+ params=params,
1365
+ dropout_rng=dropout_rng,
1366
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1367
+ train=True,
1368
+ )[0]
1369
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"])
1370
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1371
+
1372
+ return loss
1373
+
1374
+ grad_fn = jax.value_and_grad(compute_loss)
1375
+
1376
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1377
+ loss, grad = grad_fn(to_dtype(state.params), batch)
1378
+
1379
+ # Custom gradient accumulation
1380
+ else:
1381
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1382
+ batch = jax.tree_util.tree_map(
1383
+ lambda x: x.reshape(
1384
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1385
+ ),
1386
+ batch,
1387
+ )
1388
+
1389
+ def accum_minibatch_step(accum_grad, minibatch):
1390
+ # compute loss, num labels and grad over minibatch and accumulate
1391
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
1392
+ return jax.tree_util.tree_map(jnp.add, accum_grad, grad), loss
1393
+
1394
+ # create an initial state for accumulating losses, num labels and gradients
1395
+ init_grad = jax.tree_util.tree_map(jnp.zeros_like, to_dtype(state.params))
1396
+ # loop accum minibatch step over the number of gradient accumulation steps
1397
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
1398
+
1399
+ # update state
1400
+ new_state = state.apply_gradients(
1401
+ grads=grad,
1402
+ dropout_rng=new_dropout_rng,
1403
+ to_dtype=to_dtype,
1404
+ )
1405
+
1406
+ # compute gradient norms over all layers and globally for detailed monitoring
1407
+ layer_grad_norm = jax.tree_util.tree_map(jnp.linalg.norm, grad)
1408
+ logs = {
1409
+ "layer_grad_norm": layer_grad_norm,
1410
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
1411
+ }
1412
+
1413
+ # compute parameter norms over all layers and globally for detailed monitoring
1414
+ layer_param_norm = jax.tree_util.tree_map(jnp.linalg.norm, new_state.params)
1415
+ logs["layer_param_norm"] = layer_param_norm
1416
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
1417
+
1418
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1419
+ metrics.update(logs)
1420
+
1421
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1422
+ # metrics = to_fp32(metrics)
1423
+
1424
+ return new_state, metrics
1425
+
1426
+ # Define eval fn
1427
+ def eval_step(params, batch):
1428
+ labels = batch.pop("labels")
1429
+ logits = model(**batch, params=params, train=False)[0]
1430
+
1431
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
1432
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1433
+
1434
+ pred_ids = jnp.argmax(logits, axis=-1)
1435
+
1436
+ # summarize metrics
1437
+ metrics = {"loss": loss}
1438
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1439
+ # metrics = to_fp32(metrics)
1440
+ return metrics, pred_ids
1441
+
1442
+ # Create parallel version of the train and eval step
1443
+ if training_args.do_train:
1444
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1445
+
1446
+ if training_args.do_eval:
1447
+ p_eval_step = jax.pmap(eval_step, "batch")
1448
+
1449
+ def run_evaluation(step):
1450
+ if training_args.do_eval:
1451
+ # ======================== Evaluating ==============================
1452
+ eval_metrics = []
1453
+ eval_preds = []
1454
+ eval_labels = []
1455
+
1456
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1457
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size)
1458
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1459
+
1460
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1461
+ samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx]
1462
+ batch = data_collator(samples)
1463
+ labels = batch["labels"]
1464
+
1465
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1466
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1467
+ eval_metrics.append(metrics)
1468
+
1469
+ eval_labels.extend(labels)
1470
+
1471
+ # normalize eval metrics
1472
+ eval_metrics = get_metrics(eval_metrics)
1473
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1474
+ eval_metrics = to_fp32(eval_metrics)
1475
+
1476
+ # always run compute metrics
1477
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1478
+ eval_metrics.update(error_rate_metric)
1479
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1480
+
1481
+ # Print metrics and update progress bar
1482
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1483
+ epochs.write(desc)
1484
+ epochs.desc = desc
1485
+
1486
+ # Save metrics
1487
+ write_wandb_log(eval_metrics, step, prefix="eval")
1488
+ write_wandb_pred(pred_str, label_str, step)
1489
+ # if has_tensorboard and jax.process_index() == 0:
1490
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1491
+
1492
+ def save_checkpoint(step):
1493
+ # save and push checkpoint to the hub
1494
+ if jax.process_index() == 0:
1495
+ params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
1496
+ model.save_pretrained(training_args.output_dir, params=params)
1497
+ tokenizer.save_pretrained(training_args.output_dir)
1498
+ if training_args.push_to_hub:
1499
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1500
+
1501
+ skip_epochs = data_args.skip_steps // (num_train_samples // batch_size_per_update)
1502
+ logger.info("***** Running training *****")
1503
+ logger.info(f" Num examples = {num_train_samples}")
1504
+ logger.info(f" Num Epochs = {num_epochs}")
1505
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1506
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1507
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1508
+ logger.info(f" Total optimization steps = {total_train_steps}")
1509
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
1510
+ logger.info(f" Use scan: {config.use_scan}")
1511
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
1512
+ logger.info(f" Skipping: {data_args.skip_steps} steps ({skip_epochs} epochs)")
1513
+
1514
+ train_time = cur_step = 0
1515
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1516
+ for epoch in epochs:
1517
+ if training_args.do_train:
1518
+ # ======================== Training ================================
1519
+ train_start = time.time()
1520
+
1521
+ if epoch < skip_epochs:
1522
+ logger.info(f"Skipping epoch {epoch + 1}")
1523
+ continue
1524
+
1525
+ # Create sampling rng
1526
+ rng, input_rng = jax.random.split(rng)
1527
+
1528
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1529
+ train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
1530
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1531
+
1532
+ if data_args.skip_steps > cur_step:
1533
+ logger.info(f"Skipping {data_args.skip_steps - (epoch * (num_train_samples // batch_size_per_update))} steps...")
1534
+ # Gather the indices for creating the batch and do a training step
1535
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1536
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1537
+ if cur_step <= data_args.skip_steps:
1538
+ continue
1539
+
1540
+ samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx]
1541
+ batch = data_collator(samples)
1542
+ batch = shard(batch.data)
1543
+ try:
1544
+ state, train_metric = p_train_step(state, batch)
1545
+ except TypeError as e:
1546
+ logger.warning("Encountered following error: \n", e)
1547
+
1548
+
1549
+ if cur_step % training_args.logging_steps == 0:
1550
+ # Save metrics
1551
+ train_metric = unreplicate(train_metric)
1552
+ train_time += time.time() - train_start
1553
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1554
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name)
1555
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1556
+ # if has_tensorboard and jax.process_index() == 0:
1557
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1558
+
1559
+ epochs.write(
1560
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1561
+ )
1562
+
1563
+ if cur_step % total_train_steps == 0:
1564
+ break
1565
+
1566
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1567
+ run_evaluation(cur_step)
1568
+
1569
+ if cur_step % training_args.save_steps == 0:
1570
+ save_checkpoint(cur_step)
1571
+
1572
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1573
+ # run evaluation at the end of the epoch if eval steps are not specified
1574
+ run_evaluation(cur_step)
1575
+ save_checkpoint(cur_step)
1576
+
1577
+ if training_args.do_train:
1578
+ save_checkpoint(cur_step)
1579
+
1580
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1581
+
1582
+ if training_args.do_eval:
1583
+ run_evaluation(cur_step)
1584
+
1585
+ # TODO: collapse 'do_predict' into the run_evaluation function
1586
+ if training_args.do_predict:
1587
+ for split in [data_args.test_split_name]:
1588
+ # ======================== Evaluating ==============================
1589
+ eval_metrics = []
1590
+ eval_preds = []
1591
+ eval_labels = []
1592
+
1593
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
1594
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1595
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1596
+
1597
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
1598
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1599
+ batch = data_collator(samples)
1600
+ labels = batch["labels"]
1601
+
1602
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1603
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1604
+ eval_metrics.append(metrics)
1605
+
1606
+ eval_labels.extend(labels)
1607
+
1608
+ # normalize eval metrics
1609
+ eval_metrics = get_metrics(eval_metrics)
1610
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
1611
+ eval_metrics = to_fp32(eval_metrics)
1612
+
1613
+ # always run compute metrics
1614
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1615
+ eval_metrics.update(error_rate_metric)
1616
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1617
+
1618
+ # Print metrics and update progress bar
1619
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1620
+ epochs.write(desc)
1621
+ epochs.desc = desc
1622
+
1623
+ # Save metrics
1624
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
1625
+ write_wandb_pred(pred_str, label_str, cur_step, prefix=split)
1626
+ # if has_tensorboard and jax.process_index() == 0:
1627
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
1628
+
1629
+
1630
+ if __name__ == "__main__":
1631
+ main()
wandb/run-20220805_230151-2y71vcu4/files/config.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.9
7
+ code_path: code/run_flax_speech_recognition_ctc.py
8
+ framework: huggingface
9
+ huggingface_version: 4.21.0
10
+ is_jupyter_run: false
11
+ is_kaggle_kernel: false
12
+ python_version: 3.8.10
13
+ start_time: 1659740511
14
+ t:
15
+ 1:
16
+ - 1
17
+ - 2
18
+ - 3
19
+ - 11
20
+ - 12
21
+ 3:
22
+ - 13
23
+ 4: 3.8.10
24
+ 5: 0.12.9
25
+ 6: 4.21.0
26
+ 8:
27
+ - 5
wandb/run-20220805_230151-2y71vcu4/files/diff.patch ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/config.json b/config.json
2
+ index 260219f..246b797 100644
3
+ --- a/config.json
4
+ +++ b/config.json
5
+ @@ -5,7 +5,7 @@
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ - "Wav2Vec2ForCTC"
10
+ + "Wav2Vec2ForPreTraining"
11
+ ],
12
+ "attention_dropout": 0.094,
13
+ "bos_token_id": 1,
14
+ diff --git a/run.sh b/run.sh
15
+ index 9cc498e..8758978 100755
16
+ --- a/run.sh
17
+ +++ b/run.sh
18
+ @@ -1,6 +1,6 @@
19
+ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_ctc.py \
20
+ --model_name_or_path="facebook/wav2vec2-xls-r-1b" \
21
+ - --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst" \
22
+ + --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst-tpu" \
23
+ --tokenizer_name="./" \
24
+ --output_dir="./" \
25
+ --overwrite_output_dir \
26
+ @@ -11,13 +11,13 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
27
+ --precision="full_mixed" \
28
+ --matmul_precision="bfloat16" \
29
+ --multisteps \
30
+ - --learning_rate="2e-4" \
31
+ + --learning_rate="1e-4" \
32
+ --warmup_steps="2000" \
33
+ --length_column_name="input_length" \
34
+ --evaluation_strategy="steps" \
35
+ --text_column_name="text" \
36
+ - --save_steps="4000" \
37
+ - --eval_steps="4000" \
38
+ + --save_steps="5000" \
39
+ + --eval_steps="5000" \
40
+ --logging_steps="100" \
41
+ --layerdrop="0.041" \
42
+ --attention_dropout="0.094" \
43
+ @@ -42,7 +42,7 @@ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_c
44
+ --ctc_zero_infinity \
45
+ --do_lower_case \
46
+ --wandb_project="wav2vec2" \
47
+ - --wandb_name="wav2vec2-1b-npsc-nst" \
48
+ + --wandb_name="wav2vec2-1b-npsc-nst-tpu" \
49
+ --remove_punctuation
50
+
51
+
52
+ diff --git a/special_tokens_map.json b/special_tokens_map.json
53
+ index 89389bf..308786b 100644
54
+ --- a/special_tokens_map.json
55
+ +++ b/special_tokens_map.json
56
+ @@ -343,6 +343,48 @@
57
+ "rstrip": false,
58
+ "single_word": false
59
+ },
60
+ + {
61
+ + "content": "</s>",
62
+ + "lstrip": false,
63
+ + "normalized": true,
64
+ + "rstrip": false,
65
+ + "single_word": false
66
+ + },
67
+ + {
68
+ + "content": "<s>",
69
+ + "lstrip": false,
70
+ + "normalized": true,
71
+ + "rstrip": false,
72
+ + "single_word": false
73
+ + },
74
+ + {
75
+ + "content": "</s>",
76
+ + "lstrip": false,
77
+ + "normalized": true,
78
+ + "rstrip": false,
79
+ + "single_word": false
80
+ + },
81
+ + {
82
+ + "content": "<s>",
83
+ + "lstrip": false,
84
+ + "normalized": true,
85
+ + "rstrip": false,
86
+ + "single_word": false
87
+ + },
88
+ + {
89
+ + "content": "</s>",
90
+ + "lstrip": false,
91
+ + "normalized": true,
92
+ + "rstrip": false,
93
+ + "single_word": false
94
+ + },
95
+ + {
96
+ + "content": "<s>",
97
+ + "lstrip": false,
98
+ + "normalized": true,
99
+ + "rstrip": false,
100
+ + "single_word": false
101
+ + },
102
+ {
103
+ "content": "</s>",
104
+ "lstrip": false,
105
+ diff --git a/wandb/debug-internal.log b/wandb/debug-internal.log
106
+ index 50a0b69..23926ef 120000
107
+ --- a/wandb/debug-internal.log
108
+ +++ b/wandb/debug-internal.log
109
+ @@ -1 +1 @@
110
+ -run-20220803_091109-yit1e59z/logs/debug-internal.log
111
+
112
+ +run-20220805_230151-2y71vcu4/logs/debug-internal.log
113
+
114
+ diff --git a/wandb/debug.log b/wandb/debug.log
115
+ index 746223d..279853d 120000
116
+ --- a/wandb/debug.log
117
+ +++ b/wandb/debug.log
118
+ @@ -1 +1 @@
119
+ -run-20220803_091109-yit1e59z/logs/debug.log
120
+
121
+ +run-20220805_230151-2y71vcu4/logs/debug.log
122
+
123
+ diff --git a/wandb/latest-run b/wandb/latest-run
124
+ index be58b40..f069a7a 120000
125
+ --- a/wandb/latest-run
126
+ +++ b/wandb/latest-run
127
+ @@ -1 +1 @@
128
+ -run-20220803_091109-yit1e59z
129
+
130
+ +run-20220805_230151-2y71vcu4
131
+
wandb/run-20220805_230151-2y71vcu4/files/output.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:864734adde2a9711d54e4c8f0ab743d16890f9d9702b4457b261c095c32ea5e0
3
+ size 239534
wandb/run-20220805_230151-2y71vcu4/files/requirements.txt ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ appdirs==1.4.4
5
+ astunparse==1.6.3
6
+ async-timeout==4.0.2
7
+ attrs==21.4.0
8
+ audioread==2.1.9
9
+ backcall==0.2.0
10
+ cachetools==4.2.4
11
+ certifi==2021.10.8
12
+ cffi==1.15.1
13
+ charset-normalizer==2.0.10
14
+ chex==0.1.3
15
+ click==8.0.3
16
+ cloud-tpu-client==0.10
17
+ cloud-tpu-profiler==2.4.0
18
+ clu==0.0.6
19
+ colorama==0.4.5
20
+ commonmark==0.9.1
21
+ configparser==5.2.0
22
+ contextlib2==21.6.0
23
+ cycler==0.11.0
24
+ datasets==2.4.0
25
+ decorator==5.1.0
26
+ dill==0.3.4
27
+ dm-tree==0.1.6
28
+ docker-pycreds==0.4.0
29
+ etils==0.6.0
30
+ exceptiongroup==1.0.0rc8
31
+ filelock==3.4.2
32
+ flatbuffers==2.0
33
+ flax==0.5.3
34
+ fonttools==4.28.5
35
+ frozenlist==1.2.0
36
+ fsspec==2021.11.1
37
+ future==0.18.2
38
+ gast==0.4.0
39
+ gitdb==4.0.9
40
+ gitpython==3.1.26
41
+ google-api-core==1.31.5
42
+ google-api-python-client==1.8.0
43
+ google-auth-httplib2==0.1.0
44
+ google-auth-oauthlib==0.4.6
45
+ google-auth==2.3.3
46
+ google-pasta==0.2.0
47
+ googleapis-common-protos==1.54.0
48
+ grpcio==1.43.0
49
+ h5py==3.6.0
50
+ httplib2==0.20.2
51
+ huggingface-hub==0.2.1
52
+ hypothesis==6.53.0
53
+ idna==3.3
54
+ importlib-metadata==4.10.0
55
+ importlib-resources==5.4.0
56
+ ipython==7.31.0
57
+ jax==0.3.15
58
+ jaxlib==0.3.15
59
+ jedi==0.18.1
60
+ jiwer==2.3.0
61
+ joblib==1.1.0
62
+ keras-preprocessing==1.1.2
63
+ keras==2.7.0
64
+ kiwisolver==1.3.2
65
+ libclang==12.0.0
66
+ librosa==0.9.2
67
+ libtpu-nightly==0.1.dev20220722
68
+ llvmlite==0.39.0
69
+ markdown==3.3.6
70
+ matplotlib-inline==0.1.3
71
+ matplotlib==3.5.1
72
+ ml-collections==0.1.0
73
+ msgpack==1.0.3
74
+ multidict==5.2.0
75
+ multiprocess==0.70.12.2
76
+ numba==0.56.0
77
+ numpy==1.22.0
78
+ oauth2client==4.1.3
79
+ oauthlib==3.1.1
80
+ opt-einsum==3.3.0
81
+ optax==0.1.3
82
+ packaging==21.3
83
+ pandas==1.3.5
84
+ parso==0.8.3
85
+ pathtools==0.1.2
86
+ pexpect==4.8.0
87
+ pickleshare==0.7.5
88
+ pillow==9.0.0
89
+ pip==22.2.1
90
+ pkg-resources==0.0.0
91
+ pooch==1.6.0
92
+ promise==2.3
93
+ prompt-toolkit==3.0.24
94
+ protobuf==3.19.1
95
+ psutil==5.9.0
96
+ ptyprocess==0.7.0
97
+ pyarrow==6.0.1
98
+ pyasn1-modules==0.2.8
99
+ pyasn1==0.4.8
100
+ pycparser==2.21
101
+ pyctcdecode==0.4.0
102
+ pygments==2.11.1
103
+ pygtrie==2.5.0
104
+ pyparsing==3.0.6
105
+ python-dateutil==2.8.2
106
+ python-levenshtein==0.12.2
107
+ pytz==2021.3
108
+ pyyaml==6.0
109
+ regex==2021.11.10
110
+ requests-oauthlib==1.3.0
111
+ requests==2.27.0
112
+ resampy==0.3.1
113
+ responses==0.18.0
114
+ rich==11.2.0
115
+ rsa==4.8
116
+ sacremoses==0.0.46
117
+ scikit-learn==1.1.1
118
+ scipy==1.7.3
119
+ sentry-sdk==1.5.2
120
+ setuptools==44.0.0
121
+ shortuuid==1.0.8
122
+ six==1.16.0
123
+ smmap==5.0.0
124
+ sortedcontainers==2.4.0
125
+ soundfile==0.10.3.post1
126
+ sox==1.4.1
127
+ subprocess32==3.5.4
128
+ tensorboard-data-server==0.6.1
129
+ tensorboard-plugin-wit==1.8.0
130
+ tensorboard==2.7.0
131
+ tensorflow-cpu==2.7.0
132
+ tensorflow-datasets==4.4.0
133
+ tensorflow-estimator==2.7.0
134
+ tensorflow-io-gcs-filesystem==0.23.1
135
+ tensorflow-metadata==1.5.0
136
+ tensorflow==2.7.0
137
+ tensorstore==0.1.21
138
+ termcolor==1.1.0
139
+ threadpoolctl==3.1.0
140
+ tokenizers==0.11.2
141
+ toolz==0.11.2
142
+ torch==1.12.0
143
+ torchaudio==0.12.0+cpu
144
+ tqdm==4.62.3
145
+ traitlets==5.1.1
146
+ transformers==4.21.0
147
+ typing-extensions==4.3.0
148
+ uritemplate==3.0.1
149
+ urllib3==1.26.7
150
+ wandb==0.12.9
151
+ wcwidth==0.2.5
152
+ werkzeug==2.0.2
153
+ wheel==0.37.1
154
+ wrapt==1.13.3
155
+ xxhash==2.0.2
156
+ yarl==1.7.2
157
+ yaspin==2.1.0
158
+ zipp==3.7.0
wandb/run-20220805_230151-2y71vcu4/files/wandb-metadata.json ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2022-08-05T23:01:55.087413",
5
+ "startedAt": "2022-08-05T23:01:51.745697",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--model_name_or_path=facebook/wav2vec2-xls-r-1b",
11
+ "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst-tpu",
12
+ "--tokenizer_name=./",
13
+ "--output_dir=./",
14
+ "--overwrite_output_dir",
15
+ "--num_train_epochs=40",
16
+ "--per_device_train_batch_size=2",
17
+ "--per_device_eval_batch_size=2",
18
+ "--gradient_accumulation_steps=1",
19
+ "--precision=full_mixed",
20
+ "--matmul_precision=bfloat16",
21
+ "--multisteps",
22
+ "--learning_rate=1e-4",
23
+ "--warmup_steps=2000",
24
+ "--length_column_name=input_length",
25
+ "--evaluation_strategy=steps",
26
+ "--text_column_name=text",
27
+ "--save_steps=5000",
28
+ "--eval_steps=5000",
29
+ "--logging_steps=100",
30
+ "--layerdrop=0.041",
31
+ "--attention_dropout=0.094",
32
+ "--activation_dropout=0.055",
33
+ "--hidden_dropout=0.047",
34
+ "--save_total_limit=5",
35
+ "--freeze_feature_encoder",
36
+ "--feat_proj_dropout=0.04",
37
+ "--mask_time_prob=0.082",
38
+ "--mask_time_length=10",
39
+ "--mask_feature_prob=0.25",
40
+ "--mask_feature_length=64",
41
+ "--gradient_checkpointing",
42
+ "--min_duration_in_seconds=0.5",
43
+ "--max_duration_in_seconds=30.0",
44
+ "--use_auth_token",
45
+ "--seed=42",
46
+ "--group_by_length",
47
+ "--do_train",
48
+ "--do_eval",
49
+ "--push_to_hub",
50
+ "--preprocessing_num_workers=32",
51
+ "--ctc_zero_infinity",
52
+ "--do_lower_case",
53
+ "--wandb_project=wav2vec2",
54
+ "--wandb_name=wav2vec2-1b-npsc-nst-tpu",
55
+ "--remove_punctuation"
56
+ ],
57
+ "state": "running",
58
+ "program": "run_flax_speech_recognition_ctc.py",
59
+ "codePath": "run_flax_speech_recognition_ctc.py",
60
+ "git": {
61
+ "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst-tpu",
62
+ "commit": "e2b1320cc68c3ce129a1d654965e0d3eb44e0558"
63
+ },
64
+ "email": "versae@gmail.com",
65
+ "root": "/data/wav2vec2-1b-npsc-nst",
66
+ "host": "t1v-n-eedfb410-w-0",
67
+ "username": "javierr",
68
+ "executable": "/data/flax/bin/python"
69
+ }
wandb/run-20220805_230151-2y71vcu4/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"train/grad_norm": 8.0625, "layer_grad_norm/": {"lm_head": {"bias": 0.140625, "kernel": 2.40625}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 0.10791015625, "scale": 0.091796875}, "layers": {"0": {"attention": {"k_proj": {"bias": 0.00018215179443359375, "kernel": 0.138671875}, "out_proj": {"bias": 0.119140625, "kernel": 0.86328125}, "q_proj": {"bias": 0.0157470703125, "kernel": 0.1904296875}, "v_proj": {"bias": 0.1064453125, "kernel": 0.84765625}}, "feed_forward": {"intermediate_dense": {"bias": 0.10546875, "kernel": 1.515625}, "output_dense": {"bias": 0.03759765625, "kernel": 1.296875}}, "final_layer_norm": {"bias": 0.30078125, "scale": 0.48828125}, "layer_norm": {"bias": 0.181640625, "scale": 0.314453125}}, "1": {"attention": {"k_proj": {"bias": 0.00012874603271484375, "kernel": 0.123046875}, "out_proj": {"bias": 0.0419921875, "kernel": 0.5390625}, "q_proj": {"bias": 0.01007080078125, "kernel": 0.119140625}, "v_proj": {"bias": 0.0625, "kernel": 0.419921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.051025390625, "kernel": 0.765625}, "output_dense": {"bias": 0.037841796875, "kernel": 0.65625}}, "final_layer_norm": {"bias": 0.095703125, "scale": 0.09326171875}, "layer_norm": {"bias": 0.10546875, "scale": 0.146484375}}, "10": {"attention": {"k_proj": {"bias": 8.726119995117188e-05, "kernel": 0.244140625}, "out_proj": {"bias": 0.02587890625, "kernel": 0.427734375}, "q_proj": {"bias": 0.0137939453125, "kernel": 0.2490234375}, "v_proj": {"bias": 0.040771484375, "kernel": 0.439453125}}, "feed_forward": {"intermediate_dense": {"bias": 0.039306640625, "kernel": 0.66015625}, "output_dense": {"bias": 0.024169921875, "kernel": 0.53125}}, "final_layer_norm": {"bias": 0.06640625, "scale": 0.0654296875}, "layer_norm": {"bias": 0.08251953125, "scale": 0.1181640625}}, "11": {"attention": {"k_proj": {"bias": 0.0001239776611328125, "kernel": 0.26953125}, "out_proj": {"bias": 0.0240478515625, "kernel": 0.50390625}, "q_proj": {"bias": 0.0166015625, "kernel": 0.26171875}, "v_proj": {"bias": 0.03955078125, "kernel": 0.515625}}, "feed_forward": {"intermediate_dense": {"bias": 0.033203125, "kernel": 0.58203125}, "output_dense": {"bias": 0.023193359375, "kernel": 0.4453125}}, "final_layer_norm": {"bias": 0.05322265625, "scale": 0.05712890625}, "layer_norm": {"bias": 0.0703125, "scale": 0.1064453125}}, "12": {"attention": {"k_proj": {"bias": 8.58306884765625e-05, "kernel": 0.197265625}, "out_proj": {"bias": 0.0234375, "kernel": 0.400390625}, "q_proj": {"bias": 0.00970458984375, "kernel": 0.19140625}, "v_proj": {"bias": 0.03662109375, "kernel": 0.431640625}}, "feed_forward": {"intermediate_dense": {"bias": 0.03369140625, "kernel": 0.56640625}, "output_dense": {"bias": 0.0225830078125, "kernel": 0.41796875}}, "final_layer_norm": {"bias": 0.057373046875, "scale": 0.05712890625}, "layer_norm": {"bias": 0.060546875, "scale": 0.059326171875}}, "13": {"attention": {"k_proj": {"bias": 0.00017261505126953125, "kernel": 0.259765625}, "out_proj": {"bias": 0.0235595703125, "kernel": 0.484375}, "q_proj": {"bias": 0.01336669921875, "kernel": 0.240234375}, "v_proj": {"bias": 0.039306640625, "kernel": 0.54296875}}, "feed_forward": {"intermediate_dense": {"bias": 0.032958984375, "kernel": 0.55078125}, "output_dense": {"bias": 0.0230712890625, "kernel": 0.42578125}}, "final_layer_norm": {"bias": 0.05517578125, "scale": 0.06201171875}, "layer_norm": {"bias": 0.06103515625, "scale": 0.06787109375}}, "14": {"attention": {"k_proj": {"bias": 0.000148773193359375, "kernel": 0.19140625}, "out_proj": {"bias": 0.0228271484375, "kernel": 0.4453125}, "q_proj": {"bias": 0.0096435546875, "kernel": 0.1787109375}, "v_proj": {"bias": 0.03466796875, "kernel": 0.45703125}}, "feed_forward": {"intermediate_dense": {"bias": 0.03369140625, "kernel": 0.5625}, "output_dense": {"bias": 0.022705078125, "kernel": 0.43359375}}, "final_layer_norm": {"bias": 0.0556640625, "scale": 0.061279296875}, "layer_norm": {"bias": 0.0537109375, "scale": 0.05419921875}}, "15": {"attention": {"k_proj": {"bias": 0.0001983642578125, "kernel": 0.2490234375}, "out_proj": {"bias": 0.0233154296875, "kernel": 0.5859375}, "q_proj": {"bias": 0.012939453125, "kernel": 0.224609375}, "v_proj": {"bias": 0.0400390625, "kernel": 0.5859375}}, "feed_forward": {"intermediate_dense": {"bias": 0.03271484375, "kernel": 0.53125}, "output_dense": {"bias": 0.0223388671875, "kernel": 0.4296875}}, "final_layer_norm": {"bias": 0.05419921875, "scale": 0.06494140625}, "layer_norm": {"bias": 0.060546875, "scale": 0.0625}}, "16": {"attention": {"k_proj": {"bias": 0.00013446807861328125, "kernel": 0.22265625}, "out_proj": {"bias": 0.0224609375, "kernel": 0.384765625}, "q_proj": {"bias": 0.011962890625, "kernel": 0.2109375}, "v_proj": {"bias": 0.03466796875, "kernel": 0.40234375}}, "feed_forward": {"intermediate_dense": {"bias": 0.03173828125, "kernel": 0.5234375}, "output_dense": {"bias": 0.021484375, "kernel": 0.421875}}, "final_layer_norm": {"bias": 0.054931640625, "scale": 0.0498046875}, "layer_norm": {"bias": 0.057861328125, "scale": 0.05517578125}}, "17": {"attention": {"k_proj": {"bias": 0.0001239776611328125, "kernel": 0.177734375}, "out_proj": {"bias": 0.023193359375, "kernel": 0.330078125}, "q_proj": {"bias": 0.00933837890625, "kernel": 0.154296875}, "v_proj": {"bias": 0.034912109375, "kernel": 0.37109375}}, "feed_forward": {"intermediate_dense": {"bias": 0.03125, "kernel": 0.51171875}, "output_dense": {"bias": 0.0220947265625, "kernel": 0.400390625}}, "final_layer_norm": {"bias": 0.0517578125, "scale": 0.0478515625}, "layer_norm": {"bias": 0.05615234375, "scale": 0.0546875}}, "18": {"attention": {"k_proj": {"bias": 0.000110626220703125, "kernel": 0.201171875}, "out_proj": {"bias": 0.02197265625, "kernel": 0.404296875}, "q_proj": {"bias": 0.0098876953125, "kernel": 0.1806640625}, "v_proj": {"bias": 0.03369140625, "kernel": 0.40625}}, "feed_forward": {"intermediate_dense": {"bias": 0.02978515625, "kernel": 0.4921875}, "output_dense": {"bias": 0.0208740234375, "kernel": 0.408203125}}, "final_layer_norm": {"bias": 0.049560546875, "scale": 0.04296875}, "layer_norm": {"bias": 0.052978515625, "scale": 0.047119140625}}, "19": {"attention": {"k_proj": {"bias": 6.341934204101562e-05, "kernel": 0.13671875}, "out_proj": {"bias": 0.021484375, "kernel": 0.291015625}, "q_proj": {"bias": 0.0076904296875, "kernel": 0.1298828125}, "v_proj": {"bias": 0.03125, "kernel": 0.32421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.028076171875, "kernel": 0.478515625}, "output_dense": {"bias": 0.020751953125, "kernel": 0.41015625}}, "final_layer_norm": {"bias": 0.048583984375, "scale": 0.04296875}, "layer_norm": {"bias": 0.047119140625, "scale": 0.036376953125}}, "2": {"attention": {"k_proj": {"bias": 0.0001392364501953125, "kernel": 0.150390625}, "out_proj": {"bias": 0.04296875, "kernel": 0.56640625}, "q_proj": {"bias": 0.01214599609375, "kernel": 0.1484375}, "v_proj": {"bias": 0.06787109375, "kernel": 0.51953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.057373046875, "kernel": 0.953125}, "output_dense": {"bias": 0.03759765625, "kernel": 0.6953125}}, "final_layer_norm": {"bias": 0.1103515625, "scale": 0.0927734375}, "layer_norm": {"bias": 0.099609375, "scale": 0.130859375}}, "20": {"attention": {"k_proj": {"bias": 3.62396240234375e-05, "kernel": 0.0966796875}, "out_proj": {"bias": 0.02197265625, "kernel": 0.2080078125}, "q_proj": {"bias": 0.0048828125, "kernel": 0.08984375}, "v_proj": {"bias": 0.030517578125, "kernel": 0.24609375}}, "feed_forward": {"intermediate_dense": {"bias": 0.02734375, "kernel": 0.486328125}, "output_dense": {"bias": 0.021484375, "kernel": 0.396484375}}, "final_layer_norm": {"bias": 0.044921875, "scale": 0.03955078125}, "layer_norm": {"bias": 0.04541015625, "scale": 0.0322265625}}, "21": {"attention": {"k_proj": {"bias": 9.584426879882812e-05, "kernel": 0.12158203125}, "out_proj": {"bias": 0.0218505859375, "kernel": 0.283203125}, "q_proj": {"bias": 0.006561279296875, "kernel": 0.11962890625}, "v_proj": {"bias": 0.031005859375, "kernel": 0.3125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0277099609375, "kernel": 0.49609375}, "output_dense": {"bias": 0.0218505859375, "kernel": 0.40234375}}, "final_layer_norm": {"bias": 0.045166015625, "scale": 0.0439453125}, "layer_norm": {"bias": 0.042724609375, "scale": 0.037353515625}}, "22": {"attention": {"k_proj": {"bias": 5.507469177246094e-05, "kernel": 0.138671875}, "out_proj": {"bias": 0.0234375, "kernel": 0.26953125}, "q_proj": {"bias": 0.007659912109375, "kernel": 0.140625}, "v_proj": {"bias": 0.0322265625, "kernel": 0.2890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.029296875, "kernel": 0.51171875}, "output_dense": {"bias": 0.0233154296875, "kernel": 0.400390625}}, "final_layer_norm": {"bias": 0.04736328125, "scale": 0.05126953125}, "layer_norm": {"bias": 0.05029296875, "scale": 0.06640625}}, "23": {"attention": {"k_proj": {"bias": 0.0001430511474609375, "kernel": 0.2080078125}, "out_proj": {"bias": 0.0242919921875, "kernel": 0.462890625}, "q_proj": {"bias": 0.010986328125, "kernel": 0.2109375}, "v_proj": {"bias": 0.03662109375, "kernel": 0.453125}}, "feed_forward": {"intermediate_dense": {"bias": 0.02978515625, "kernel": 0.5234375}, "output_dense": {"bias": 0.0238037109375, "kernel": 0.4140625}}, "final_layer_norm": {"bias": 0.048828125, "scale": 0.04638671875}, "layer_norm": {"bias": 0.05810546875, "scale": 0.06884765625}}, "24": {"attention": {"k_proj": {"bias": 9.441375732421875e-05, "kernel": 0.19140625}, "out_proj": {"bias": 0.0224609375, "kernel": 0.37109375}, "q_proj": {"bias": 0.01025390625, "kernel": 0.189453125}, "v_proj": {"bias": 0.035888671875, "kernel": 0.37890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0289306640625, "kernel": 0.53125}, "output_dense": {"bias": 0.021728515625, "kernel": 0.390625}}, "final_layer_norm": {"bias": 0.0478515625, "scale": 0.045654296875}, "layer_norm": {"bias": 0.06005859375, "scale": 0.0517578125}}, "25": {"attention": {"k_proj": {"bias": 0.0001239776611328125, "kernel": 0.189453125}, "out_proj": {"bias": 0.0224609375, "kernel": 0.35546875}, "q_proj": {"bias": 0.0106201171875, "kernel": 0.1875}, "v_proj": {"bias": 0.033935546875, "kernel": 0.373046875}}, "feed_forward": {"intermediate_dense": {"bias": 0.02880859375, "kernel": 0.51953125}, "output_dense": {"bias": 0.02197265625, "kernel": 0.3671875}}, "final_layer_norm": {"bias": 0.049072265625, "scale": 0.04443359375}, "layer_norm": {"bias": 0.051025390625, "scale": 0.04296875}}, "26": {"attention": {"k_proj": {"bias": 7.62939453125e-05, "kernel": 0.181640625}, "out_proj": {"bias": 0.022216796875, "kernel": 0.345703125}, "q_proj": {"bias": 0.0108642578125, "kernel": 0.193359375}, "v_proj": {"bias": 0.034912109375, "kernel": 0.359375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0279541015625, "kernel": 0.484375}, "output_dense": {"bias": 0.0220947265625, "kernel": 0.37890625}}, "final_layer_norm": {"bias": 0.0458984375, "scale": 0.04931640625}, "layer_norm": {"bias": 0.052978515625, "scale": 0.0693359375}}, "27": {"attention": {"k_proj": {"bias": 0.00016021728515625, "kernel": 0.24609375}, "out_proj": {"bias": 0.0205078125, "kernel": 0.4453125}, "q_proj": {"bias": 0.013916015625, "kernel": 0.2470703125}, "v_proj": {"bias": 0.034423828125, "kernel": 0.44140625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0286865234375, "kernel": 0.48828125}, "output_dense": {"bias": 0.02099609375, "kernel": 0.3828125}}, "final_layer_norm": {"bias": 0.0478515625, "scale": 0.05126953125}, "layer_norm": {"bias": 0.05859375, "scale": 0.099609375}}, "28": {"attention": {"k_proj": {"bias": 0.00012063980102539062, "kernel": 0.2119140625}, "out_proj": {"bias": 0.019287109375, "kernel": 0.43359375}, "q_proj": {"bias": 0.0130615234375, "kernel": 0.228515625}, "v_proj": {"bias": 0.03076171875, "kernel": 0.412109375}}, "feed_forward": {"intermediate_dense": {"bias": 0.02734375, "kernel": 0.4765625}, "output_dense": {"bias": 0.019775390625, "kernel": 0.380859375}}, "final_layer_norm": {"bias": 0.0458984375, "scale": 0.05615234375}, "layer_norm": {"bias": 0.0546875, "scale": 0.07275390625}}, "29": {"attention": {"k_proj": {"bias": 9.632110595703125e-05, "kernel": 0.2353515625}, "out_proj": {"bias": 0.0184326171875, "kernel": 0.373046875}, "q_proj": {"bias": 0.01165771484375, "kernel": 0.234375}, "v_proj": {"bias": 0.0302734375, "kernel": 0.376953125}}, "feed_forward": {"intermediate_dense": {"bias": 0.027099609375, "kernel": 0.53125}, "output_dense": {"bias": 0.01806640625, "kernel": 0.37109375}}, "final_layer_norm": {"bias": 0.0439453125, "scale": 0.041259765625}, "layer_norm": {"bias": 0.0556640625, "scale": 0.05029296875}}, "3": {"attention": {"k_proj": {"bias": 0.0002079010009765625, "kernel": 0.2421875}, "out_proj": {"bias": 0.03955078125, "kernel": 0.75}, "q_proj": {"bias": 0.016357421875, "kernel": 0.23046875}, "v_proj": {"bias": 0.0634765625, "kernel": 0.671875}}, "feed_forward": {"intermediate_dense": {"bias": 0.05419921875, "kernel": 0.90625}, "output_dense": {"bias": 0.0361328125, "kernel": 0.671875}}, "final_layer_norm": {"bias": 0.09228515625, "scale": 0.0947265625}, "layer_norm": {"bias": 0.0966796875, "scale": 0.171875}}, "30": {"attention": {"k_proj": {"bias": 0.00015354156494140625, "kernel": 0.244140625}, "out_proj": {"bias": 0.0177001953125, "kernel": 0.375}, "q_proj": {"bias": 0.0123291015625, "kernel": 0.2470703125}, "v_proj": {"bias": 0.02880859375, "kernel": 0.39453125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0257568359375, "kernel": 0.5078125}, "output_dense": {"bias": 0.01708984375, "kernel": 0.326171875}}, "final_layer_norm": {"bias": 0.042236328125, "scale": 0.0517578125}, "layer_norm": {"bias": 0.044921875, "scale": 0.04443359375}}, "31": {"attention": {"k_proj": {"bias": 0.00011348724365234375, "kernel": 0.2138671875}, "out_proj": {"bias": 0.01611328125, "kernel": 0.30859375}, "q_proj": {"bias": 0.01141357421875, "kernel": 0.228515625}, "v_proj": {"bias": 0.025146484375, "kernel": 0.3359375}}, "feed_forward": {"intermediate_dense": {"bias": 0.02294921875, "kernel": 0.44921875}, "output_dense": {"bias": 0.0155029296875, "kernel": 0.28515625}}, "final_layer_norm": {"bias": 0.03662109375, "scale": 0.035400390625}, "layer_norm": {"bias": 0.04296875, "scale": 0.054443359375}}, "32": {"attention": {"k_proj": {"bias": 0.0001583099365234375, "kernel": 0.177734375}, "out_proj": {"bias": 0.01495361328125, "kernel": 0.302734375}, "q_proj": {"bias": 0.00897216796875, "kernel": 0.1748046875}, "v_proj": {"bias": 0.0234375, "kernel": 0.326171875}}, "feed_forward": {"intermediate_dense": {"bias": 0.0216064453125, "kernel": 0.4296875}, "output_dense": {"bias": 0.013916015625, "kernel": 0.267578125}}, "final_layer_norm": {"bias": 0.036376953125, "scale": 0.03125}, "layer_norm": {"bias": 0.038330078125, "scale": 0.064453125}}, "33": {"attention": {"k_proj": {"bias": 0.00018310546875, "kernel": 0.1875}, "out_proj": {"bias": 0.0133056640625, "kernel": 0.294921875}, "q_proj": {"bias": 0.0087890625, "kernel": 0.18359375}, "v_proj": {"bias": 0.021728515625, "kernel": 0.318359375}}, "feed_forward": {"intermediate_dense": {"bias": 0.020263671875, "kernel": 0.392578125}, "output_dense": {"bias": 0.0126953125, "kernel": 0.2578125}}, "final_layer_norm": {"bias": 0.036865234375, "scale": 0.041015625}, "layer_norm": {"bias": 0.033447265625, "scale": 0.0361328125}}, "34": {"attention": {"k_proj": {"bias": 0.00014400482177734375, "kernel": 0.203125}, "out_proj": {"bias": 0.01165771484375, "kernel": 0.275390625}, "q_proj": {"bias": 0.009033203125, "kernel": 0.1943359375}, "v_proj": {"bias": 0.01953125, "kernel": 0.2890625}}, "feed_forward": {"intermediate_dense": {"bias": 0.0184326171875, "kernel": 0.357421875}, "output_dense": {"bias": 0.01129150390625, "kernel": 0.248046875}}, "final_layer_norm": {"bias": 0.032958984375, "scale": 0.045654296875}, "layer_norm": {"bias": 0.03369140625, "scale": 0.0341796875}}, "35": {"attention": {"k_proj": {"bias": 0.0001678466796875, "kernel": 0.2041015625}, "out_proj": {"bias": 0.0098876953125, "kernel": 0.28515625}, "q_proj": {"bias": 0.009765625, "kernel": 0.2060546875}, "v_proj": {"bias": 0.01483154296875, "kernel": 0.255859375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0150146484375, "kernel": 0.279296875}, "output_dense": {"bias": 0.0098876953125, "kernel": 0.2080078125}}, "final_layer_norm": {"bias": 0.02734375, "scale": 0.0301513671875}, "layer_norm": {"bias": 0.03466796875, "scale": 0.04345703125}}, "36": {"attention": {"k_proj": {"bias": 0.0001220703125, "kernel": 0.1455078125}, "out_proj": {"bias": 0.00909423828125, "kernel": 0.2177734375}, "q_proj": {"bias": 0.007049560546875, "kernel": 0.140625}, "v_proj": {"bias": 0.013427734375, "kernel": 0.2021484375}}, "feed_forward": {"intermediate_dense": {"bias": 0.01251220703125, "kernel": 0.234375}, "output_dense": {"bias": 0.0089111328125, "kernel": 0.1845703125}}, "final_layer_norm": {"bias": 0.021484375, "scale": 0.0230712890625}, "layer_norm": {"bias": 0.0277099609375, "scale": 0.0267333984375}}, "37": {"attention": {"k_proj": {"bias": 0.0001087188720703125, "kernel": 0.158203125}, "out_proj": {"bias": 0.0081787109375, "kernel": 0.212890625}, "q_proj": {"bias": 0.007659912109375, "kernel": 0.150390625}, "v_proj": {"bias": 0.01300048828125, "kernel": 0.197265625}}, "feed_forward": {"intermediate_dense": {"bias": 0.01141357421875, "kernel": 0.21484375}, "output_dense": {"bias": 0.00799560546875, "kernel": 0.16796875}}, "final_layer_norm": {"bias": 0.020263671875, "scale": 0.01806640625}, "layer_norm": {"bias": 0.0283203125, "scale": 0.044677734375}}, "38": {"attention": {"k_proj": {"bias": 0.0001354217529296875, "kernel": 0.1357421875}, "out_proj": {"bias": 0.007354736328125, "kernel": 0.1845703125}, "q_proj": {"bias": 0.0059814453125, "kernel": 0.12890625}, "v_proj": {"bias": 0.011474609375, "kernel": 0.1796875}}, "feed_forward": {"intermediate_dense": {"bias": 0.01025390625, "kernel": 0.193359375}, "output_dense": {"bias": 0.0072021484375, "kernel": 0.154296875}}, "final_layer_norm": {"bias": 0.0189208984375, "scale": 0.01513671875}, "layer_norm": {"bias": 0.024658203125, "scale": 0.0286865234375}}, "39": {"attention": {"k_proj": {"bias": 7.677078247070312e-05, "kernel": 0.1259765625}, "out_proj": {"bias": 0.00640869140625, "kernel": 0.1708984375}, "q_proj": {"bias": 0.00567626953125, "kernel": 0.12109375}, "v_proj": {"bias": 0.01031494140625, "kernel": 0.169921875}}, "feed_forward": {"intermediate_dense": {"bias": 0.00860595703125, "kernel": 0.169921875}, "output_dense": {"bias": 0.006195068359375, "kernel": 0.150390625}}, "final_layer_norm": {"bias": 0.01611328125, "scale": 0.01422119140625}, "layer_norm": {"bias": 0.0240478515625, "scale": 0.0244140625}}, "4": {"attention": {"k_proj": {"bias": 0.0002346038818359375, "kernel": 0.283203125}, "out_proj": {"bias": 0.03759765625, "kernel": 0.796875}, "q_proj": {"bias": 0.0179443359375, "kernel": 0.279296875}, "v_proj": {"bias": 0.058349609375, "kernel": 0.734375}}, "feed_forward": {"intermediate_dense": {"bias": 0.05078125, "kernel": 0.8203125}, "output_dense": {"bias": 0.034423828125, "kernel": 0.62890625}}, "final_layer_norm": {"bias": 0.0888671875, "scale": 0.09814453125}, "layer_norm": {"bias": 0.0908203125, "scale": 0.08544921875}}, "40": {"attention": {"k_proj": {"bias": 4.482269287109375e-05, "kernel": 0.1181640625}, "out_proj": {"bias": 0.006072998046875, "kernel": 0.138671875}, "q_proj": {"bias": 0.004791259765625, "kernel": 0.09765625}, "v_proj": {"bias": 0.00885009765625, "kernel": 0.1328125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0081787109375, "kernel": 0.1513671875}, "output_dense": {"bias": 0.005859375, "kernel": 0.1298828125}}, "final_layer_norm": {"bias": 0.016845703125, "scale": 0.01806640625}, "layer_norm": {"bias": 0.018310546875, "scale": 0.0198974609375}}, "41": {"attention": {"k_proj": {"bias": 6.246566772460938e-05, "kernel": 0.11181640625}, "out_proj": {"bias": 0.00543212890625, "kernel": 0.1640625}, "q_proj": {"bias": 0.004302978515625, "kernel": 0.1015625}, "v_proj": {"bias": 0.0091552734375, "kernel": 0.173828125}}, "feed_forward": {"intermediate_dense": {"bias": 0.006805419921875, "kernel": 0.1376953125}, "output_dense": {"bias": 0.00531005859375, "kernel": 0.119140625}}, "final_layer_norm": {"bias": 0.0142822265625, "scale": 0.01708984375}, "layer_norm": {"bias": 0.019287109375, "scale": 0.0267333984375}}, "42": {"attention": {"k_proj": {"bias": 2.0384788513183594e-05, "kernel": 0.0400390625}, "out_proj": {"bias": 0.00537109375, "kernel": 0.0908203125}, "q_proj": {"bias": 0.001922607421875, "kernel": 0.039306640625}, "v_proj": {"bias": 0.006591796875, "kernel": 0.08642578125}}, "feed_forward": {"intermediate_dense": {"bias": 0.00604248046875, "kernel": 0.1259765625}, "output_dense": {"bias": 0.00537109375, "kernel": 0.1015625}}, "final_layer_norm": {"bias": 0.01263427734375, "scale": 0.014892578125}, "layer_norm": {"bias": 0.01220703125, "scale": 0.02099609375}}, "43": {"attention": {"k_proj": {"bias": 6.854534149169922e-06, "kernel": 0.019287109375}, "out_proj": {"bias": 0.00543212890625, "kernel": 0.0634765625}, "q_proj": {"bias": 0.0009918212890625, "kernel": 0.01953125}, "v_proj": {"bias": 0.005859375, "kernel": 0.0625}}, "feed_forward": {"intermediate_dense": {"bias": 0.005950927734375, "kernel": 0.12890625}, "output_dense": {"bias": 0.005523681640625, "kernel": 0.09912109375}}, "final_layer_norm": {"bias": 0.011474609375, "scale": 0.01495361328125}, "layer_norm": {"bias": 0.0107421875, "scale": 0.0145263671875}}, "44": {"attention": {"k_proj": {"bias": 6.377696990966797e-06, "kernel": 0.02197265625}, "out_proj": {"bias": 0.005615234375, "kernel": 0.068359375}, "q_proj": {"bias": 0.00116729736328125, "kernel": 0.0225830078125}, "v_proj": {"bias": 0.006256103515625, "kernel": 0.068359375}}, "feed_forward": {"intermediate_dense": {"bias": 0.005523681640625, "kernel": 0.130859375}, "output_dense": {"bias": 0.005767822265625, "kernel": 0.0947265625}}, "final_layer_norm": {"bias": 0.01025390625, "scale": 0.01220703125}, "layer_norm": {"bias": 0.01318359375, "scale": 0.01171875}}, "45": {"attention": {"k_proj": {"bias": 8.881092071533203e-06, "kernel": 0.021240234375}, "out_proj": {"bias": 0.00579833984375, "kernel": 0.06884765625}, "q_proj": {"bias": 0.001190185546875, "kernel": 0.020751953125}, "v_proj": {"bias": 0.007049560546875, "kernel": 0.0732421875}}, "feed_forward": {"intermediate_dense": {"bias": 0.00543212890625, "kernel": 0.126953125}, "output_dense": {"bias": 0.00592041015625, "kernel": 0.0927734375}}, "final_layer_norm": {"bias": 0.01025390625, "scale": 0.013916015625}, "layer_norm": {"bias": 0.0179443359375, "scale": 0.021484375}}, "46": {"attention": {"k_proj": {"bias": 1.049041748046875e-05, "kernel": 0.02490234375}, "out_proj": {"bias": 0.0059814453125, "kernel": 0.0693359375}, "q_proj": {"bias": 0.0014495849609375, "kernel": 0.0230712890625}, "v_proj": {"bias": 0.00830078125, "kernel": 0.0849609375}}, "feed_forward": {"intermediate_dense": {"bias": 0.004913330078125, "kernel": 0.103515625}, "output_dense": {"bias": 0.00604248046875, "kernel": 0.0830078125}}, "final_layer_norm": {"bias": 0.0118408203125, "scale": 0.02001953125}, "layer_norm": {"bias": 0.02294921875, "scale": 0.01806640625}}, "47": {"attention": {"k_proj": {"bias": 2.2411346435546875e-05, "kernel": 0.027099609375}, "out_proj": {"bias": 0.006103515625, "kernel": 0.0615234375}, "q_proj": {"bias": 0.00177001953125, "kernel": 0.02294921875}, "v_proj": {"bias": 0.010009765625, "kernel": 0.09375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0045166015625, "kernel": 0.076171875}, "output_dense": {"bias": 0.0062255859375, "kernel": 0.07470703125}}, "final_layer_norm": {"bias": 0.0125732421875, "scale": 0.019775390625}, "layer_norm": {"bias": 0.032470703125, "scale": 0.0284423828125}}, "5": {"attention": {"k_proj": {"bias": 0.000118255615234375, "kernel": 0.251953125}, "out_proj": {"bias": 0.03662109375, "kernel": 0.515625}, "q_proj": {"bias": 0.0137939453125, "kernel": 0.2470703125}, "v_proj": {"bias": 0.056640625, "kernel": 0.52734375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0498046875, "kernel": 0.7890625}, "output_dense": {"bias": 0.03466796875, "kernel": 0.6015625}}, "final_layer_norm": {"bias": 0.087890625, "scale": 0.07373046875}, "layer_norm": {"bias": 0.0888671875, "scale": 0.10546875}}, "6": {"attention": {"k_proj": {"bias": 0.00016021728515625, "kernel": 0.296875}, "out_proj": {"bias": 0.03466796875, "kernel": 0.6953125}, "q_proj": {"bias": 0.0169677734375, "kernel": 0.28515625}, "v_proj": {"bias": 0.05908203125, "kernel": 0.734375}}, "feed_forward": {"intermediate_dense": {"bias": 0.0478515625, "kernel": 0.7890625}, "output_dense": {"bias": 0.032958984375, "kernel": 0.57421875}}, "final_layer_norm": {"bias": 0.0849609375, "scale": 0.08447265625}, "layer_norm": {"bias": 0.0927734375, "scale": 0.11865234375}}, "7": {"attention": {"k_proj": {"bias": 0.0002079010009765625, "kernel": 0.30078125}, "out_proj": {"bias": 0.03271484375, "kernel": 0.6796875}, "q_proj": {"bias": 0.018798828125, "kernel": 0.29296875}, "v_proj": {"bias": 0.05419921875, "kernel": 0.7265625}}, "feed_forward": {"intermediate_dense": {"bias": 0.048583984375, "kernel": 0.8125}, "output_dense": {"bias": 0.03125, "kernel": 0.5859375}}, "final_layer_norm": {"bias": 0.08984375, "scale": 0.09033203125}, "layer_norm": {"bias": 0.0966796875, "scale": 0.107421875}}, "8": {"attention": {"k_proj": {"bias": 0.0001468658447265625, "kernel": 0.27734375}, "out_proj": {"bias": 0.03076171875, "kernel": 0.58984375}, "q_proj": {"bias": 0.016357421875, "kernel": 0.275390625}, "v_proj": {"bias": 0.052001953125, "kernel": 0.62109375}}, "feed_forward": {"intermediate_dense": {"bias": 0.046630859375, "kernel": 0.7890625}, "output_dense": {"bias": 0.0294189453125, "kernel": 0.58984375}}, "final_layer_norm": {"bias": 0.080078125, "scale": 0.0830078125}, "layer_norm": {"bias": 0.08642578125, "scale": 0.12451171875}}, "9": {"attention": {"k_proj": {"bias": 0.00026702880859375, "kernel": 0.369140625}, "out_proj": {"bias": 0.0269775390625, "kernel": 0.77734375}, "q_proj": {"bias": 0.01904296875, "kernel": 0.36328125}, "v_proj": {"bias": 0.044921875, "kernel": 0.78125}}, "feed_forward": {"intermediate_dense": {"bias": 0.0390625, "kernel": 0.671875}, "output_dense": {"bias": 0.0263671875, "kernel": 0.53125}}, "final_layer_norm": {"bias": 0.06884765625, "scale": 0.068359375}, "layer_norm": {"bias": 0.08349609375, "scale": 0.0927734375}}}, "pos_conv_embed": {"conv": {"bias": 0.1298828125, "weight_g": 0.1240234375, "weight_v": 0.953125}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "1": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "2": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "3": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "4": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "5": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}, "6": {"conv": {"bias": 0.0, "kernel": 0.0}, "layer_norm": {"bias": 0.0, "scale": 0.0}}}}, "feature_projection": {"layer_norm": {"bias": 0.34765625, "scale": 0.45703125}, "projection": {"bias": 0.169921875, "kernel": 2.96875}}, "masked_spec_embed": 0.0}}, "layer_param_norm/": {"lm_head": {"bias": 0.0668909028172493, "kernel": 5.564761161804199}, "wav2vec2": {"encoder": {"layer_norm": {"bias": 0.9411773681640625, "scale": 22.735441207885742}, "layers": {"0": {"attention": {"k_proj": {"bias": 0.1572660654783249, "kernel": 26.265689849853516}, "out_proj": {"bias": 1.540636420249939, "kernel": 25.332489013671875}, "q_proj": {"bias": 1.345062494277954, "kernel": 26.540752410888672}, "v_proj": {"bias": 0.34915056824684143, "kernel": 25.96399688720703}}, "feed_forward": {"intermediate_dense": {"bias": 1.7922512292861938, "kernel": 96.14887237548828}, "output_dense": {"bias": 1.0332554578781128, "kernel": 91.98905181884766}}, "final_layer_norm": {"bias": 1.3027501106262207, "scale": 19.902679443359375}, "layer_norm": {"bias": 3.2321395874023438, "scale": 15.86198616027832}}, "1": {"attention": {"k_proj": {"bias": 0.14874659478664398, "kernel": 39.97101593017578}, "out_proj": {"bias": 1.3095301389694214, "kernel": 41.88142776489258}, "q_proj": {"bias": 2.8878695964813232, "kernel": 39.850341796875}, "v_proj": {"bias": 0.2854360044002533, "kernel": 40.33198547363281}}, "feed_forward": {"intermediate_dense": {"bias": 1.6540393829345703, "kernel": 94.21043395996094}, "output_dense": {"bias": 0.8036744594573975, "kernel": 84.77149963378906}}, "final_layer_norm": {"bias": 1.2088514566421509, "scale": 18.529014587402344}, "layer_norm": {"bias": 1.8125758171081543, "scale": 18.96085548400879}}, "10": {"attention": {"k_proj": {"bias": 0.1729704886674881, "kernel": 47.594642639160156}, "out_proj": {"bias": 1.2517130374908447, "kernel": 50.368370056152344}, "q_proj": {"bias": 2.4282305240631104, "kernel": 47.5654411315918}, "v_proj": {"bias": 0.31612780690193176, "kernel": 50.5357780456543}}, "feed_forward": {"intermediate_dense": {"bias": 1.6828261613845825, "kernel": 98.35198974609375}, "output_dense": {"bias": 0.560376763343811, "kernel": 92.19532775878906}}, "final_layer_norm": {"bias": 2.2686586380004883, "scale": 20.38600730895996}, "layer_norm": {"bias": 1.7384451627731323, "scale": 22.122407913208008}}, "11": {"attention": {"k_proj": {"bias": 0.20578038692474365, "kernel": 47.44072723388672}, "out_proj": {"bias": 1.1011972427368164, "kernel": 49.64116287231445}, "q_proj": {"bias": 2.4754388332366943, "kernel": 47.175193786621094}, "v_proj": {"bias": 0.37477394938468933, "kernel": 50.14057922363281}}, "feed_forward": {"intermediate_dense": {"bias": 1.7321875095367432, "kernel": 99.11766052246094}, "output_dense": {"bias": 0.5437204837799072, "kernel": 93.87814331054688}}, "final_layer_norm": {"bias": 2.2578558921813965, "scale": 20.41146469116211}, "layer_norm": {"bias": 1.7143844366073608, "scale": 22.52808952331543}}, "12": {"attention": {"k_proj": {"bias": 0.1765490174293518, "kernel": 48.007171630859375}, "out_proj": {"bias": 1.0823392868041992, "kernel": 49.907554626464844}, "q_proj": {"bias": 2.351224899291992, "kernel": 47.764991760253906}, "v_proj": {"bias": 0.356992244720459, "kernel": 50.29539489746094}}, "feed_forward": {"intermediate_dense": {"bias": 1.7750531435012817, "kernel": 99.95246887207031}, "output_dense": {"bias": 0.5311572551727295, "kernel": 95.42961120605469}}, "final_layer_norm": {"bias": 2.2109756469726562, "scale": 20.368785858154297}, "layer_norm": {"bias": 1.7906843423843384, "scale": 23.005416870117188}}, "13": {"attention": {"k_proj": {"bias": 0.18811266124248505, "kernel": 49.8721923828125}, "out_proj": {"bias": 1.076762318611145, "kernel": 49.606414794921875}, "q_proj": {"bias": 2.322277307510376, "kernel": 49.74319076538086}, "v_proj": {"bias": 0.3691728115081787, "kernel": 49.76460647583008}}, "feed_forward": {"intermediate_dense": {"bias": 1.829641342163086, "kernel": 100.58248138427734}, "output_dense": {"bias": 0.5479011535644531, "kernel": 95.87431335449219}}, "final_layer_norm": {"bias": 2.109241485595703, "scale": 20.520580291748047}, "layer_norm": {"bias": 1.9025626182556152, "scale": 23.28399085998535}}, "14": {"attention": {"k_proj": {"bias": 0.22546514868736267, "kernel": 50.01702117919922}, "out_proj": {"bias": 1.2384552955627441, "kernel": 47.94068908691406}, "q_proj": {"bias": 2.382194995880127, "kernel": 50.0910758972168}, "v_proj": {"bias": 0.35839876532554626, "kernel": 47.508934020996094}}, "feed_forward": {"intermediate_dense": {"bias": 1.86549711227417, "kernel": 101.19500732421875}, "output_dense": {"bias": 0.5690709948539734, "kernel": 97.23810577392578}}, "final_layer_norm": {"bias": 2.2292520999908447, "scale": 20.66695213317871}, "layer_norm": {"bias": 2.0422415733337402, "scale": 23.39774513244629}}, "15": {"attention": {"k_proj": {"bias": 0.18015369772911072, "kernel": 50.14893341064453}, "out_proj": {"bias": 1.2943248748779297, "kernel": 48.63526916503906}, "q_proj": {"bias": 2.523224353790283, "kernel": 50.23434829711914}, "v_proj": {"bias": 0.4106982350349426, "kernel": 48.305335998535156}}, "feed_forward": {"intermediate_dense": {"bias": 1.8729243278503418, "kernel": 101.01643371582031}, "output_dense": {"bias": 0.7231183052062988, "kernel": 97.94534301757812}}, "final_layer_norm": {"bias": 2.1691672801971436, "scale": 20.768878936767578}, "layer_norm": {"bias": 2.275395154953003, "scale": 23.74152374267578}}, "16": {"attention": {"k_proj": {"bias": 0.1596199870109558, "kernel": 50.00080871582031}, "out_proj": {"bias": 1.217071771621704, "kernel": 48.02516555786133}, "q_proj": {"bias": 2.6123976707458496, "kernel": 49.893272399902344}, "v_proj": {"bias": 0.3556521236896515, "kernel": 47.71862030029297}}, "feed_forward": {"intermediate_dense": {"bias": 1.8646655082702637, "kernel": 101.64906311035156}, "output_dense": {"bias": 0.745720386505127, "kernel": 98.78334045410156}}, "final_layer_norm": {"bias": 2.2145400047302246, "scale": 21.247108459472656}, "layer_norm": {"bias": 2.196000814437866, "scale": 22.552139282226562}}, "17": {"attention": {"k_proj": {"bias": 0.15246188640594482, "kernel": 50.159278869628906}, "out_proj": {"bias": 1.1550219058990479, "kernel": 47.34062194824219}, "q_proj": {"bias": 2.700528144836426, "kernel": 50.237125396728516}, "v_proj": {"bias": 0.3945621848106384, "kernel": 47.01970672607422}}, "feed_forward": {"intermediate_dense": {"bias": 1.875906229019165, "kernel": 102.71083068847656}, "output_dense": {"bias": 0.763410210609436, "kernel": 99.16596984863281}}, "final_layer_norm": {"bias": 2.297891616821289, "scale": 21.794700622558594}, "layer_norm": {"bias": 2.0964274406433105, "scale": 21.963212966918945}}, "18": {"attention": {"k_proj": {"bias": 0.17540258169174194, "kernel": 50.5605583190918}, "out_proj": {"bias": 1.258694052696228, "kernel": 48.3692626953125}, "q_proj": {"bias": 2.578209400177002, "kernel": 50.937294006347656}, "v_proj": {"bias": 0.42091283202171326, "kernel": 47.91288757324219}}, "feed_forward": {"intermediate_dense": {"bias": 1.912561297416687, "kernel": 103.00012969970703}, "output_dense": {"bias": 0.8791664838790894, "kernel": 100.81649780273438}}, "final_layer_norm": {"bias": 2.39349102973938, "scale": 21.778369903564453}, "layer_norm": {"bias": 2.2818918228149414, "scale": 23.874561309814453}}, "19": {"attention": {"k_proj": {"bias": 0.14555136859416962, "kernel": 49.727783203125}, "out_proj": {"bias": 1.2327790260314941, "kernel": 48.223941802978516}, "q_proj": {"bias": 2.850454330444336, "kernel": 50.13501739501953}, "v_proj": {"bias": 0.3850022852420807, "kernel": 47.482051849365234}}, "feed_forward": {"intermediate_dense": {"bias": 1.962997555732727, "kernel": 103.55958557128906}, "output_dense": {"bias": 0.9461472034454346, "kernel": 101.71952819824219}}, "final_layer_norm": {"bias": 2.3411974906921387, "scale": 22.13956069946289}, "layer_norm": {"bias": 2.1903624534606934, "scale": 22.96629524230957}}, "2": {"attention": {"k_proj": {"bias": 0.16273045539855957, "kernel": 46.281925201416016}, "out_proj": {"bias": 1.224226713180542, "kernel": 44.15675735473633}, "q_proj": {"bias": 3.0906779766082764, "kernel": 46.02598571777344}, "v_proj": {"bias": 0.3187757134437561, "kernel": 44.112770080566406}}, "feed_forward": {"intermediate_dense": {"bias": 1.6842701435089111, "kernel": 99.14990234375}, "output_dense": {"bias": 0.6949976086616516, "kernel": 87.93550872802734}}, "final_layer_norm": {"bias": 1.4942196607589722, "scale": 21.074359893798828}, "layer_norm": {"bias": 1.6854312419891357, "scale": 21.76892852783203}}, "20": {"attention": {"k_proj": {"bias": 0.13836346566677094, "kernel": 49.7047004699707}, "out_proj": {"bias": 1.2614378929138184, "kernel": 47.5956916809082}, "q_proj": {"bias": 2.763533115386963, "kernel": 50.45436096191406}, "v_proj": {"bias": 0.35550954937934875, "kernel": 46.5172004699707}}, "feed_forward": {"intermediate_dense": {"bias": 1.9570611715316772, "kernel": 104.79476165771484}, "output_dense": {"bias": 1.0612233877182007, "kernel": 102.35823822021484}}, "final_layer_norm": {"bias": 2.3444409370422363, "scale": 23.082225799560547}, "layer_norm": {"bias": 2.165036201477051, "scale": 23.06981086730957}}, "21": {"attention": {"k_proj": {"bias": 0.15504375100135803, "kernel": 50.163543701171875}, "out_proj": {"bias": 1.3001418113708496, "kernel": 47.61540222167969}, "q_proj": {"bias": 2.7271313667297363, "kernel": 50.96747589111328}, "v_proj": {"bias": 0.4037085771560669, "kernel": 46.739601135253906}}, "feed_forward": {"intermediate_dense": {"bias": 1.9970109462738037, "kernel": 104.97686004638672}, "output_dense": {"bias": 1.140713095664978, "kernel": 102.67676544189453}}, "final_layer_norm": {"bias": 2.370567798614502, "scale": 22.737632751464844}, "layer_norm": {"bias": 2.2212581634521484, "scale": 23.29868507385254}}, "22": {"attention": {"k_proj": {"bias": 0.16382987797260284, "kernel": 50.56926727294922}, "out_proj": {"bias": 1.2247910499572754, "kernel": 47.1660041809082}, "q_proj": {"bias": 2.7855143547058105, "kernel": 50.92293930053711}, "v_proj": {"bias": 0.35865116119384766, "kernel": 47.0435791015625}}, "feed_forward": {"intermediate_dense": {"bias": 1.9342015981674194, "kernel": 105.3548583984375}, "output_dense": {"bias": 1.1557631492614746, "kernel": 101.9817886352539}}, "final_layer_norm": {"bias": 2.264819622039795, "scale": 22.244813919067383}, "layer_norm": {"bias": 2.229525566101074, "scale": 22.353832244873047}}, "23": {"attention": {"k_proj": {"bias": 0.21246662735939026, "kernel": 51.59439468383789}, "out_proj": {"bias": 1.364210844039917, "kernel": 48.22317886352539}, "q_proj": {"bias": 2.6520018577575684, "kernel": 51.69462585449219}, "v_proj": {"bias": 0.5232762098312378, "kernel": 48.86384582519531}}, "feed_forward": {"intermediate_dense": {"bias": 1.9091339111328125, "kernel": 105.17338562011719}, "output_dense": {"bias": 1.1327968835830688, "kernel": 102.7879638671875}}, "final_layer_norm": {"bias": 2.5350351333618164, "scale": 22.22771644592285}, "layer_norm": {"bias": 2.701894760131836, "scale": 23.687578201293945}}, "24": {"attention": {"k_proj": {"bias": 0.18005745112895966, "kernel": 50.28105926513672}, "out_proj": {"bias": 1.421328067779541, "kernel": 50.1175537109375}, "q_proj": {"bias": 2.7797858715057373, "kernel": 50.268226623535156}, "v_proj": {"bias": 0.481330931186676, "kernel": 50.19598388671875}}, "feed_forward": {"intermediate_dense": {"bias": 2.0205435752868652, "kernel": 104.63185119628906}, "output_dense": {"bias": 1.171934962272644, "kernel": 105.64089965820312}}, "final_layer_norm": {"bias": 2.61605167388916, "scale": 22.27845001220703}, "layer_norm": {"bias": 2.4518842697143555, "scale": 23.27135467529297}}, "25": {"attention": {"k_proj": {"bias": 0.1678708791732788, "kernel": 50.65785217285156}, "out_proj": {"bias": 1.2310528755187988, "kernel": 48.01655197143555}, "q_proj": {"bias": 2.8710856437683105, "kernel": 50.43537139892578}, "v_proj": {"bias": 0.5626442432403564, "kernel": 48.55765914916992}}, "feed_forward": {"intermediate_dense": {"bias": 1.9219928979873657, "kernel": 104.88790130615234}, "output_dense": {"bias": 1.0354797840118408, "kernel": 105.56619262695312}}, "final_layer_norm": {"bias": 2.343381881713867, "scale": 22.79658317565918}, "layer_norm": {"bias": 2.5976107120513916, "scale": 22.297609329223633}}, "26": {"attention": {"k_proj": {"bias": 0.18550752103328705, "kernel": 50.914405822753906}, "out_proj": {"bias": 1.1548503637313843, "kernel": 48.805328369140625}, "q_proj": {"bias": 2.851818561553955, "kernel": 50.67176818847656}, "v_proj": {"bias": 0.478745698928833, "kernel": 49.415489196777344}}, "feed_forward": {"intermediate_dense": {"bias": 2.024549961090088, "kernel": 104.31269836425781}, "output_dense": {"bias": 0.9997164607048035, "kernel": 102.69577026367188}}, "final_layer_norm": {"bias": 1.9872686862945557, "scale": 21.646724700927734}, "layer_norm": {"bias": 2.4832143783569336, "scale": 22.715808868408203}}, "27": {"attention": {"k_proj": {"bias": 0.41229814291000366, "kernel": 51.72254943847656}, "out_proj": {"bias": 1.3984119892120361, "kernel": 50.182186126708984}, "q_proj": {"bias": 2.616152048110962, "kernel": 51.59402847290039}, "v_proj": {"bias": 0.5837572813034058, "kernel": 50.64431381225586}}, "feed_forward": {"intermediate_dense": {"bias": 2.182586669921875, "kernel": 102.61892700195312}, "output_dense": {"bias": 0.8772892951965332, "kernel": 102.38434600830078}}, "final_layer_norm": {"bias": 2.2664685249328613, "scale": 20.938594818115234}, "layer_norm": {"bias": 2.558232307434082, "scale": 23.56865692138672}}, "28": {"attention": {"k_proj": {"bias": 0.4503825604915619, "kernel": 52.60496139526367}, "out_proj": {"bias": 1.418035864830017, "kernel": 50.870330810546875}, "q_proj": {"bias": 2.7903497219085693, "kernel": 52.25389862060547}, "v_proj": {"bias": 0.4630447328090668, "kernel": 51.217445373535156}}, "feed_forward": {"intermediate_dense": {"bias": 2.125363349914551, "kernel": 102.62171936035156}, "output_dense": {"bias": 0.7697524428367615, "kernel": 104.55252075195312}}, "final_layer_norm": {"bias": 2.181206703186035, "scale": 21.27033233642578}, "layer_norm": {"bias": 2.0507431030273438, "scale": 24.352882385253906}}, "29": {"attention": {"k_proj": {"bias": 0.17947150766849518, "kernel": 49.040016174316406}, "out_proj": {"bias": 1.3947250843048096, "kernel": 53.36112976074219}, "q_proj": {"bias": 2.733522415161133, "kernel": 48.861534118652344}, "v_proj": {"bias": 0.4220423698425293, "kernel": 53.27291488647461}}, "feed_forward": {"intermediate_dense": {"bias": 2.115380048751831, "kernel": 103.29501342773438}, "output_dense": {"bias": 0.8926779627799988, "kernel": 108.8265380859375}}, "final_layer_norm": {"bias": 2.3777034282684326, "scale": 22.42627716064453}, "layer_norm": {"bias": 2.17183256149292, "scale": 25.36501693725586}}, "3": {"attention": {"k_proj": {"bias": 0.20873260498046875, "kernel": 50.3943977355957}, "out_proj": {"bias": 1.3884034156799316, "kernel": 46.793052673339844}, "q_proj": {"bias": 2.741361618041992, "kernel": 50.61634063720703}, "v_proj": {"bias": 0.3030884861946106, "kernel": 47.16839599609375}}, "feed_forward": {"intermediate_dense": {"bias": 1.6966179609298706, "kernel": 100.75701141357422}, "output_dense": {"bias": 0.6564837694168091, "kernel": 90.81097412109375}}, "final_layer_norm": {"bias": 1.773003339767456, "scale": 21.166318893432617}, "layer_norm": {"bias": 1.8634135723114014, "scale": 23.474411010742188}}, "30": {"attention": {"k_proj": {"bias": 0.3080509305000305, "kernel": 50.9363899230957}, "out_proj": {"bias": 1.1957952976226807, "kernel": 49.6925163269043}, "q_proj": {"bias": 2.8136203289031982, "kernel": 51.00883102416992}, "v_proj": {"bias": 0.4762360453605652, "kernel": 50.050315856933594}}, "feed_forward": {"intermediate_dense": {"bias": 2.048686981201172, "kernel": 103.77497863769531}, "output_dense": {"bias": 0.8379020690917969, "kernel": 107.83690643310547}}, "final_layer_norm": {"bias": 2.2059359550476074, "scale": 23.55939483642578}, "layer_norm": {"bias": 2.2876853942871094, "scale": 25.086034774780273}}, "31": {"attention": {"k_proj": {"bias": 0.3784014582633972, "kernel": 49.39188003540039}, "out_proj": {"bias": 1.119596242904663, "kernel": 50.44926834106445}, "q_proj": {"bias": 2.5993165969848633, "kernel": 49.48455810546875}, "v_proj": {"bias": 0.522638738155365, "kernel": 50.58903884887695}}, "feed_forward": {"intermediate_dense": {"bias": 2.1306161880493164, "kernel": 102.43810272216797}, "output_dense": {"bias": 1.0122932195663452, "kernel": 105.18507385253906}}, "final_layer_norm": {"bias": 2.1017799377441406, "scale": 23.437347412109375}, "layer_norm": {"bias": 2.281771183013916, "scale": 24.799633026123047}}, "32": {"attention": {"k_proj": {"bias": 0.25974252820014954, "kernel": 48.31093215942383}, "out_proj": {"bias": 1.1245567798614502, "kernel": 49.66279602050781}, "q_proj": {"bias": 2.827341079711914, "kernel": 48.274757385253906}, "v_proj": {"bias": 0.39349639415740967, "kernel": 49.95775604248047}}, "feed_forward": {"intermediate_dense": {"bias": 2.0681939125061035, "kernel": 101.28511810302734}, "output_dense": {"bias": 1.0686500072479248, "kernel": 104.53141784667969}}, "final_layer_norm": {"bias": 2.0827746391296387, "scale": 23.84016227722168}, "layer_norm": {"bias": 2.262401580810547, "scale": 25.117748260498047}}, "33": {"attention": {"k_proj": {"bias": 0.26482468843460083, "kernel": 48.2266960144043}, "out_proj": {"bias": 1.1604676246643066, "kernel": 49.49713897705078}, "q_proj": {"bias": 2.972139835357666, "kernel": 48.19778060913086}, "v_proj": {"bias": 0.41569650173187256, "kernel": 49.77226257324219}}, "feed_forward": {"intermediate_dense": {"bias": 2.076352834701538, "kernel": 99.64717864990234}, "output_dense": {"bias": 1.0581281185150146, "kernel": 103.25572204589844}}, "final_layer_norm": {"bias": 1.994544506072998, "scale": 23.598308563232422}, "layer_norm": {"bias": 2.439363956451416, "scale": 25.341171264648438}}, "34": {"attention": {"k_proj": {"bias": 0.2755166292190552, "kernel": 47.42335510253906}, "out_proj": {"bias": 1.4115952253341675, "kernel": 50.949275970458984}, "q_proj": {"bias": 2.8844351768493652, "kernel": 47.45287322998047}, "v_proj": {"bias": 0.38429465889930725, "kernel": 50.88615417480469}}, "feed_forward": {"intermediate_dense": {"bias": 2.164264678955078, "kernel": 98.45794677734375}, "output_dense": {"bias": 0.9942412972450256, "kernel": 102.51951599121094}}, "final_layer_norm": {"bias": 1.9257347583770752, "scale": 23.230693817138672}, "layer_norm": {"bias": 2.5131359100341797, "scale": 25.724273681640625}}, "35": {"attention": {"k_proj": {"bias": 0.3880418837070465, "kernel": 49.17444610595703}, "out_proj": {"bias": 1.338890552520752, "kernel": 49.785396575927734}, "q_proj": {"bias": 2.5939016342163086, "kernel": 49.47700500488281}, "v_proj": {"bias": 0.4775533974170685, "kernel": 49.64286422729492}}, "feed_forward": {"intermediate_dense": {"bias": 2.2472829818725586, "kernel": 97.02595520019531}, "output_dense": {"bias": 0.8886467218399048, "kernel": 101.20555877685547}}, "final_layer_norm": {"bias": 1.9989571571350098, "scale": 23.346912384033203}, "layer_norm": {"bias": 2.3170738220214844, "scale": 26.273502349853516}}, "36": {"attention": {"k_proj": {"bias": 0.23954293131828308, "kernel": 46.442291259765625}, "out_proj": {"bias": 1.3719561100006104, "kernel": 51.094085693359375}, "q_proj": {"bias": 2.6886184215545654, "kernel": 46.40093994140625}, "v_proj": {"bias": 0.3625168800354004, "kernel": 51.2959098815918}}, "feed_forward": {"intermediate_dense": {"bias": 2.129176378250122, "kernel": 96.0584716796875}, "output_dense": {"bias": 0.9140334129333496, "kernel": 100.84213256835938}}, "final_layer_norm": {"bias": 1.6761391162872314, "scale": 23.85495376586914}, "layer_norm": {"bias": 2.0316162109375, "scale": 25.767738342285156}}, "37": {"attention": {"k_proj": {"bias": 0.5490036010742188, "kernel": 45.51734924316406}, "out_proj": {"bias": 1.6235731840133667, "kernel": 51.08349609375}, "q_proj": {"bias": 2.3980464935302734, "kernel": 45.56562805175781}, "v_proj": {"bias": 0.35822078585624695, "kernel": 50.9661979675293}}, "feed_forward": {"intermediate_dense": {"bias": 2.023285388946533, "kernel": 95.24488830566406}, "output_dense": {"bias": 0.9175524115562439, "kernel": 100.54832458496094}}, "final_layer_norm": {"bias": 1.4775093793869019, "scale": 24.243749618530273}, "layer_norm": {"bias": 1.9990556240081787, "scale": 25.78732681274414}}, "38": {"attention": {"k_proj": {"bias": 0.6277763843536377, "kernel": 43.699893951416016}, "out_proj": {"bias": 1.3179919719696045, "kernel": 50.56129837036133}, "q_proj": {"bias": 2.3240532875061035, "kernel": 43.691986083984375}, "v_proj": {"bias": 0.4102995991706848, "kernel": 50.447662353515625}}, "feed_forward": {"intermediate_dense": {"bias": 1.9623992443084717, "kernel": 93.24853515625}, "output_dense": {"bias": 0.9016005992889404, "kernel": 98.77943420410156}}, "final_layer_norm": {"bias": 1.512964129447937, "scale": 24.965984344482422}, "layer_norm": {"bias": 2.1805269718170166, "scale": 26.572105407714844}}, "39": {"attention": {"k_proj": {"bias": 0.6627014875411987, "kernel": 43.5059814453125}, "out_proj": {"bias": 1.6037161350250244, "kernel": 50.41535186767578}, "q_proj": {"bias": 2.112717628479004, "kernel": 43.87850570678711}, "v_proj": {"bias": 0.390166699886322, "kernel": 50.09760665893555}}, "feed_forward": {"intermediate_dense": {"bias": 1.9482412338256836, "kernel": 91.50100708007812}, "output_dense": {"bias": 0.9774587750434875, "kernel": 99.10246276855469}}, "final_layer_norm": {"bias": 1.6250338554382324, "scale": 25.600841522216797}, "layer_norm": {"bias": 2.1454267501831055, "scale": 27.16954803466797}}, "4": {"attention": {"k_proj": {"bias": 0.21773464977741241, "kernel": 52.96697998046875}, "out_proj": {"bias": 1.5798711776733398, "kernel": 48.24322509765625}, "q_proj": {"bias": 2.5302865505218506, "kernel": 53.159481048583984}, "v_proj": {"bias": 0.3523319363594055, "kernel": 48.58223342895508}}, "feed_forward": {"intermediate_dense": {"bias": 1.6854965686798096, "kernel": 100.35298156738281}, "output_dense": {"bias": 0.8151395320892334, "kernel": 92.05111694335938}}, "final_layer_norm": {"bias": 1.8813955783843994, "scale": 20.698898315429688}, "layer_norm": {"bias": 1.9626739025115967, "scale": 23.951351165771484}}, "40": {"attention": {"k_proj": {"bias": 0.5965253114700317, "kernel": 42.784297943115234}, "out_proj": {"bias": 1.54435133934021, "kernel": 49.03851318359375}, "q_proj": {"bias": 2.0332603454589844, "kernel": 43.529693603515625}, "v_proj": {"bias": 0.44385749101638794, "kernel": 48.619354248046875}}, "feed_forward": {"intermediate_dense": {"bias": 1.8044453859329224, "kernel": 89.7060317993164}, "output_dense": {"bias": 1.0290844440460205, "kernel": 96.339111328125}}, "final_layer_norm": {"bias": 1.7928125858306885, "scale": 24.868694305419922}, "layer_norm": {"bias": 2.09533953666687, "scale": 26.69829750061035}}, "41": {"attention": {"k_proj": {"bias": 1.676910638809204, "kernel": 40.18012237548828}, "out_proj": {"bias": 1.3033416271209717, "kernel": 50.59858322143555}, "q_proj": {"bias": 1.720700979232788, "kernel": 40.91082000732422}, "v_proj": {"bias": 0.3962639570236206, "kernel": 49.54127502441406}}, "feed_forward": {"intermediate_dense": {"bias": 1.9326260089874268, "kernel": 86.49552154541016}, "output_dense": {"bias": 1.0626506805419922, "kernel": 95.35970306396484}}, "final_layer_norm": {"bias": 2.287132740020752, "scale": 28.329570770263672}, "layer_norm": {"bias": 2.0972471237182617, "scale": 28.477943420410156}}, "42": {"attention": {"k_proj": {"bias": 0.8007075786590576, "kernel": 36.84203338623047}, "out_proj": {"bias": 1.3564722537994385, "kernel": 44.805450439453125}, "q_proj": {"bias": 1.545430302619934, "kernel": 38.173866271972656}, "v_proj": {"bias": 0.6166673898696899, "kernel": 43.14790344238281}}, "feed_forward": {"intermediate_dense": {"bias": 1.6966503858566284, "kernel": 85.47465515136719}, "output_dense": {"bias": 1.1137490272521973, "kernel": 93.52703857421875}}, "final_layer_norm": {"bias": 1.9977667331695557, "scale": 29.619346618652344}, "layer_norm": {"bias": 1.5640318393707275, "scale": 27.318359375}}, "43": {"attention": {"k_proj": {"bias": 1.2139406204223633, "kernel": 33.34306335449219}, "out_proj": {"bias": 1.3445727825164795, "kernel": 41.18675231933594}, "q_proj": {"bias": 1.3703632354736328, "kernel": 34.14249038696289}, "v_proj": {"bias": 0.5302899479866028, "kernel": 39.08563995361328}}, "feed_forward": {"intermediate_dense": {"bias": 1.7241705656051636, "kernel": 84.66365051269531}, "output_dense": {"bias": 0.875501275062561, "kernel": 91.44625091552734}}, "final_layer_norm": {"bias": 1.9585936069488525, "scale": 31.831100463867188}, "layer_norm": {"bias": 1.6979660987854004, "scale": 25.48971939086914}}, "44": {"attention": {"k_proj": {"bias": 2.492720365524292, "kernel": 33.942161560058594}, "out_proj": {"bias": 1.106978416442871, "kernel": 44.9072265625}, "q_proj": {"bias": 1.3083606958389282, "kernel": 34.30083084106445}, "v_proj": {"bias": 0.38603541254997253, "kernel": 44.0037841796875}}, "feed_forward": {"intermediate_dense": {"bias": 1.799367070198059, "kernel": 83.60717010498047}, "output_dense": {"bias": 0.8168225288391113, "kernel": 89.09458923339844}}, "final_layer_norm": {"bias": 1.9565825462341309, "scale": 34.001800537109375}, "layer_norm": {"bias": 1.5935062170028687, "scale": 25.508739471435547}}, "45": {"attention": {"k_proj": {"bias": 2.0522377490997314, "kernel": 33.76741027832031}, "out_proj": {"bias": 0.9882014989852905, "kernel": 48.50224304199219}, "q_proj": {"bias": 1.3905718326568604, "kernel": 33.92951965332031}, "v_proj": {"bias": 0.4323008954524994, "kernel": 48.66568374633789}}, "feed_forward": {"intermediate_dense": {"bias": 1.914729356765747, "kernel": 80.2913589477539}, "output_dense": {"bias": 0.9462040662765503, "kernel": 84.49946594238281}}, "final_layer_norm": {"bias": 1.6955138444900513, "scale": 32.71482467651367}, "layer_norm": {"bias": 1.5199291706085205, "scale": 24.034631729125977}}, "46": {"attention": {"k_proj": {"bias": 1.5418241024017334, "kernel": 34.94271469116211}, "out_proj": {"bias": 0.7529411315917969, "kernel": 50.925716400146484}, "q_proj": {"bias": 1.5559134483337402, "kernel": 35.040340423583984}, "v_proj": {"bias": 0.3770799934864044, "kernel": 51.6875}}, "feed_forward": {"intermediate_dense": {"bias": 1.9622228145599365, "kernel": 74.68043518066406}, "output_dense": {"bias": 1.1015828847885132, "kernel": 74.83889770507812}}, "final_layer_norm": {"bias": 1.6539316177368164, "scale": 28.228740692138672}, "layer_norm": {"bias": 1.3381633758544922, "scale": 22.947879791259766}}, "47": {"attention": {"k_proj": {"bias": 0.2917669415473938, "kernel": 37.277713775634766}, "out_proj": {"bias": 0.6384203433990479, "kernel": 45.20857620239258}, "q_proj": {"bias": 1.6836624145507812, "kernel": 37.879981994628906}, "v_proj": {"bias": 0.3526262640953064, "kernel": 46.196876525878906}}, "feed_forward": {"intermediate_dense": {"bias": 2.025301933288574, "kernel": 72.08566284179688}, "output_dense": {"bias": 0.6106955409049988, "kernel": 68.43527221679688}}, "final_layer_norm": {"bias": 1.548535704612732, "scale": 23.066944122314453}, "layer_norm": {"bias": 1.059852123260498, "scale": 20.202987670898438}}, "5": {"attention": {"k_proj": {"bias": 0.16882865130901337, "kernel": 48.27493667602539}, "out_proj": {"bias": 1.5571939945220947, "kernel": 49.373023986816406}, "q_proj": {"bias": 2.630668878555298, "kernel": 48.41212463378906}, "v_proj": {"bias": 0.31896352767944336, "kernel": 50.12682342529297}}, "feed_forward": {"intermediate_dense": {"bias": 1.6141834259033203, "kernel": 100.44667053222656}, "output_dense": {"bias": 0.8445441722869873, "kernel": 91.35646057128906}}, "final_layer_norm": {"bias": 2.1409213542938232, "scale": 20.886255264282227}, "layer_norm": {"bias": 2.0031042098999023, "scale": 23.135677337646484}}, "6": {"attention": {"k_proj": {"bias": 0.2727879285812378, "kernel": 50.084373474121094}, "out_proj": {"bias": 1.5454241037368774, "kernel": 48.77831268310547}, "q_proj": {"bias": 2.6858837604522705, "kernel": 50.5609130859375}, "v_proj": {"bias": 0.3207029104232788, "kernel": 49.28154754638672}}, "feed_forward": {"intermediate_dense": {"bias": 1.5875343084335327, "kernel": 99.54606628417969}, "output_dense": {"bias": 0.7026320099830627, "kernel": 90.97244262695312}}, "final_layer_norm": {"bias": 2.438199758529663, "scale": 20.364471435546875}, "layer_norm": {"bias": 1.9867676496505737, "scale": 23.706497192382812}}, "7": {"attention": {"k_proj": {"bias": 0.2606049180030823, "kernel": 49.72987365722656}, "out_proj": {"bias": 1.3674311637878418, "kernel": 49.03389358520508}, "q_proj": {"bias": 2.4531450271606445, "kernel": 50.121490478515625}, "v_proj": {"bias": 0.4057798385620117, "kernel": 48.98875427246094}}, "feed_forward": {"intermediate_dense": {"bias": 1.5941696166992188, "kernel": 99.27568054199219}, "output_dense": {"bias": 0.5401815176010132, "kernel": 90.70722961425781}}, "final_layer_norm": {"bias": 2.292677879333496, "scale": 20.583412170410156}, "layer_norm": {"bias": 1.9037420749664307, "scale": 22.42470932006836}}, "8": {"attention": {"k_proj": {"bias": 0.24493072926998138, "kernel": 49.249691009521484}, "out_proj": {"bias": 1.188157558441162, "kernel": 49.54103088378906}, "q_proj": {"bias": 2.4314587116241455, "kernel": 49.03004837036133}, "v_proj": {"bias": 0.33289897441864014, "kernel": 49.696128845214844}}, "feed_forward": {"intermediate_dense": {"bias": 1.6514670848846436, "kernel": 98.86155700683594}, "output_dense": {"bias": 0.5011764168739319, "kernel": 90.14228057861328}}, "final_layer_norm": {"bias": 2.2437543869018555, "scale": 20.356639862060547}, "layer_norm": {"bias": 1.839092493057251, "scale": 22.84292221069336}}, "9": {"attention": {"k_proj": {"bias": 0.2741130292415619, "kernel": 50.02962112426758}, "out_proj": {"bias": 1.4082090854644775, "kernel": 50.41935348510742}, "q_proj": {"bias": 2.3752646446228027, "kernel": 50.20918655395508}, "v_proj": {"bias": 0.35025060176849365, "kernel": 50.80641174316406}}, "feed_forward": {"intermediate_dense": {"bias": 1.730910062789917, "kernel": 97.50648498535156}, "output_dense": {"bias": 0.6366411447525024, "kernel": 90.64373779296875}}, "final_layer_norm": {"bias": 2.162881851196289, "scale": 19.669281005859375}, "layer_norm": {"bias": 1.9278013706207275, "scale": 24.389545440673828}}}, "pos_conv_embed": {"conv": {"bias": 5.617927551269531, "weight_g": 8.927756309509277, "weight_v": 85.1714096069336}}}, "feature_extractor": {"conv_layers": {"0": {"conv": {"bias": 2.0290679931640625, "kernel": 20.55536460876465}, "layer_norm": {"bias": 4.550922393798828, "scale": 16.167570114135742}}, "1": {"conv": {"bias": 1.7790228128433228, "kernel": 51.24136734008789}, "layer_norm": {"bias": 5.962646961212158, "scale": 23.268157958984375}}, "2": {"conv": {"bias": 1.140576720237732, "kernel": 46.50312042236328}, "layer_norm": {"bias": 4.176670551300049, "scale": 20.370853424072266}}, "3": {"conv": {"bias": 0.6725863218307495, "kernel": 44.397525787353516}, "layer_norm": {"bias": 3.888174533843994, "scale": 17.53795051574707}}, "4": {"conv": {"bias": 0.6373162269592285, "kernel": 41.314056396484375}, "layer_norm": {"bias": 2.385471820831299, "scale": 16.34571647644043}}, "5": {"conv": {"bias": 0.5147221684455872, "kernel": 37.479759216308594}, "layer_norm": {"bias": 2.020900011062622, "scale": 17.064470291137695}}, "6": {"conv": {"bias": 0.4947893023490906, "kernel": 40.64780044555664}, "layer_norm": {"bias": 0.5876954793930054, "scale": 19.058603286743164}}}}, "feature_projection": {"layer_norm": {"bias": 6.331979751586914, "scale": 16.56066131591797}, "projection": {"bias": 1.7012581825256348, "kernel": 35.14557647705078}}, "masked_spec_embed": 11.914372444152832}}, "train/learning_rate": 9.961725299945101e-05, "train/loss": 0.8391348123550415, "train/param_norm": 1192.62841796875, "_runtime": 8064, "_timestamp": 1659748575, "_step": 4900}
wandb/run-20220805_230151-2y71vcu4/logs/debug-internal.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1418c615d6b95002fa3e56523498468b78963e0fc761d193f43d63503270a8d6
3
+ size 676709
wandb/run-20220805_230151-2y71vcu4/logs/debug.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:920d65c1c30ad3dfc773632222a1eab21e7f75a6783383b410e1f6c3b2db661a
3
+ size 2637
wandb/run-20220805_230151-2y71vcu4/run-2y71vcu4.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95295b8d59f586b5d5d84e048c8d5ed02c82c7ecf568b944db54ed22b8ce655f
3
+ size 3936514