versae committed
Commit e3be764
1 Parent(s): 63f0838

j2u4n7h4: saving weights and logs of step 1k

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full changeset.
Files changed (50)
  1. .cache +1 -0
  2. .gitattributes +1 -0
  3. added_tokens.json +4 -0
  4. config.json +109 -0
  5. events.out.tfevents.1659137001.t1v-n-eedfb410-w-0.2559923.0.v2 +3 -0
  6. events.out.tfevents.1659168077.t1v-n-eedfb410-w-0.2047809.0.v2 +3 -0
  7. events.out.tfevents.1659169464.t1v-n-eedfb410-w-0.1065426.0.v2 +3 -0
  8. events.out.tfevents.1659171045.t1v-n-eedfb410-w-0.86199.0.v2 +3 -0
  9. events.out.tfevents.1659174715.t1v-n-eedfb410-w-0.3333166.0.v2 +3 -0
  10. events.out.tfevents.1659181854.t1v-n-eedfb410-w-0.3085831.0.v2 +3 -0
  11. events.out.tfevents.1659182962.t1v-n-eedfb410-w-0.2099342.0.v2 +3 -0
  12. events.out.tfevents.1659184651.t1v-n-eedfb410-w-0.4852.0.v2 +3 -0
  13. events.out.tfevents.1659185790.t1v-n-eedfb410-w-0.3212038.0.v2 +3 -0
  14. events.out.tfevents.1659190065.t1v-n-eedfb410-w-0.2276371.0.v2 +3 -0
  15. events.out.tfevents.1659203868.t1v-n-eedfb410-w-0.1493173.0.v2 +3 -0
  16. flax_model.msgpack +3 -0
  17. models/__init__.py +6 -0
  18. models/__pycache__/__init__.cpython-38.pyc +0 -0
  19. models/__pycache__/configuration_bart.cpython-38.pyc +0 -0
  20. models/__pycache__/configuration_speech_encoder_decoder.cpython-38.pyc +0 -0
  21. models/__pycache__/configuration_wav2vec2.cpython-38.pyc +0 -0
  22. models/__pycache__/modeling_flax_bart.cpython-38.pyc +0 -0
  23. models/__pycache__/modeling_flax_speech_encoder_decoder.cpython-38.pyc +0 -0
  24. models/__pycache__/modeling_flax_wav2vec2.cpython-38.pyc +0 -0
  25. models/configuration_bart.py +183 -0
  26. models/configuration_speech_encoder_decoder.py +121 -0
  27. models/configuration_wav2vec2.py +344 -0
  28. models/modeling_flax_bart.py +816 -0
  29. models/modeling_flax_speech_encoder_decoder.py +1245 -0
  30. models/modeling_flax_wav2vec2.py +975 -0
  31. preprocessor_config.json +9 -0
  32. run.sh +48 -0
  33. run_flax_speech_recognition_ctc.py +1604 -0
  34. special_tokens_map.json +190 -0
  35. tokenizer_config.json +12 -0
  36. vocab.json +41 -0
  37. wandb/debug-internal.log +1 -0
  38. wandb/debug.log +1 -0
  39. wandb/latest-run +1 -0
  40. wandb/run-20220729_183213-356uc50u/files/code/run_flax_speech_recognition_ctc.py +1596 -0
  41. wandb/run-20220729_183213-356uc50u/files/config.yaml +33 -0
  42. wandb/run-20220729_183213-356uc50u/files/output.log +253 -0
  43. wandb/run-20220729_183213-356uc50u/files/requirements.txt +137 -0
  44. wandb/run-20220729_183213-356uc50u/files/wandb-metadata.json +67 -0
  45. wandb/run-20220729_183213-356uc50u/files/wandb-summary.json +1 -0
  46. wandb/run-20220729_183213-356uc50u/logs/debug-internal.log +301 -0
  47. wandb/run-20220729_183213-356uc50u/logs/debug.log +130 -0
  48. wandb/run-20220729_183213-356uc50u/run-356uc50u.wandb +3 -0
  49. wandb/run-20220729_184558-17ksemgv/files/code/run_flax_speech_recognition_ctc.py +1596 -0
  50. wandb/run-20220729_184558-17ksemgv/files/config.yaml +33 -0
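Taken together, the files above form a self-contained Flax checkpoint: weights (flax_model.msgpack), configuration (config.json, preprocessor_config.json), tokenizer files (vocab.json, added_tokens.json, special_tokens_map.json, tokenizer_config.json), the custom modeling code under models/, the training entry point (run.sh, run_flax_speech_recognition_ctc.py), and TensorBoard/W&B logs. A minimal, hedged sketch of running inference from a local clone of this repository with the stock Hugging Face classes (the path ".", the 16 kHz sampling rate, and the dummy input are assumptions; the LFS weights must be pulled first):

```python
import numpy as np
from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC

# Load processor and model from the local checkout (assumes `git lfs pull` was run).
processor = Wav2Vec2Processor.from_pretrained(".")
model = FlaxWav2Vec2ForCTC.from_pretrained(".")

# Dummy one-second clip at 16 kHz, just to show the call signature.
audio = np.zeros(16000, dtype=np.float32)
inputs = processor(audio, sampling_rate=16000, return_tensors="np")

logits = model(inputs.input_values).logits
pred_ids = np.argmax(np.asarray(logits), axis=-1)
print(processor.batch_decode(pred_ids))
```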
.cache ADDED
@@ -0,0 +1 @@
+ /home/javierr/.cache
.gitattributes CHANGED
@@ -29,3 +29,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zstandard filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.wandb filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,4 @@
+ {
+ "</s>": 40,
+ "<s>": 39
+ }
config.json ADDED
@@ -0,0 +1,109 @@
+ {
+ "activation_dropout": 0.055,
+ "adapter_kernel_size": 3,
+ "adapter_stride": 2,
+ "add_adapter": false,
+ "apply_spec_augment": true,
+ "architectures": [
+ "Wav2Vec2ForCTC"
+ ],
+ "attention_dropout": 0.094,
+ "bos_token_id": 1,
+ "classifier_proj_size": 256,
+ "codevector_dim": 1024,
+ "contrastive_logits_temperature": 0.1,
+ "conv_bias": true,
+ "conv_dim": [
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512
+ ],
+ "conv_kernel": [
+ 10,
+ 3,
+ 3,
+ 3,
+ 3,
+ 2,
+ 2
+ ],
+ "conv_stride": [
+ 5,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2
+ ],
+ "ctc_loss_reduction": "mean",
+ "ctc_zero_infinity": true,
+ "diversity_loss_weight": 0.1,
+ "do_stable_layer_norm": true,
+ "eos_token_id": 2,
+ "feat_extract_activation": "gelu",
+ "feat_extract_dropout": 0.0,
+ "feat_extract_norm": "layer",
+ "feat_proj_dropout": 0.04,
+ "feat_quantizer_dropout": 0.0,
+ "final_dropout": 0.0,
+ "fuse_matmuls": false,
+ "gradient_checkpointing": true,
+ "hidden_act": "gelu",
+ "hidden_dropout": 0.047,
+ "hidden_size": 1280,
+ "initializer_range": 0.02,
+ "intermediate_size": 5120,
+ "layer_norm_eps": 1e-05,
+ "layerdrop": 0.041,
+ "mask_feature_length": 64,
+ "mask_feature_min_masks": 0,
+ "mask_feature_prob": 0.25,
+ "mask_time_length": 10,
+ "mask_time_min_masks": 2,
+ "mask_time_prob": 0.082,
+ "model_type": "wav2vec2",
+ "num_adapter_layers": 3,
+ "num_attention_heads": 16,
+ "num_codevector_groups": 2,
+ "num_codevectors_per_group": 320,
+ "num_conv_pos_embedding_groups": 16,
+ "num_conv_pos_embeddings": 128,
+ "num_feat_extract_layers": 7,
+ "num_hidden_layers": 48,
+ "num_negatives": 100,
+ "output_hidden_size": 1280,
+ "pad_token_id": 38,
+ "proj_codevector_dim": 1024,
+ "tdnn_dilation": [
+ 1,
+ 2,
+ 3,
+ 1,
+ 1
+ ],
+ "tdnn_dim": [
+ 512,
+ 512,
+ 512,
+ 512,
+ 1500
+ ],
+ "tdnn_kernel": [
+ 5,
+ 3,
+ 3,
+ 1,
+ 1
+ ],
+ "torch_dtype": "float32",
+ "transformers_version": "4.21.0",
+ "use_scan": false,
+ "use_weighted_layer_sum": false,
+ "vocab_size": 39,
+ "xvector_output_dim": 512
+ }
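The config above describes a 48-layer, 1280-hidden wav2vec2 encoder with a 39-token CTC head and fairly aggressive SpecAugment (mask_time_prob 0.082, mask_feature_prob 0.25). As a hedged sketch, it can be inspected with the standard config API (loading from the local checkout is an assumption):

```python
from transformers import Wav2Vec2Config

# Reads config.json from the current directory.
config = Wav2Vec2Config.from_pretrained(".")

# Architecture: 48 transformer layers, hidden size 1280, 39-token CTC vocabulary.
print(config.num_hidden_layers, config.hidden_size, config.vocab_size)

# SpecAugment and CTC settings used for this run.
print(config.mask_time_prob, config.mask_feature_prob, config.ctc_loss_reduction)
```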
events.out.tfevents.1659137001.t1v-n-eedfb410-w-0.2559923.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ded7b585446de921ebfe6f76acb1100484ca4c264a7b066ab3fa8904e6e4b6e6
+ size 40
events.out.tfevents.1659168077.t1v-n-eedfb410-w-0.2047809.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a0e93ee6816853f3923399b5d527d2a61ca34ce887c28517df99c05fb20a7e6c
+ size 40
events.out.tfevents.1659169464.t1v-n-eedfb410-w-0.1065426.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e59e7c0f2f3d5c03effbc941690f4567f87bf74e6d559526bbd802d65dc2710a
+ size 40
events.out.tfevents.1659171045.t1v-n-eedfb410-w-0.86199.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0f89384c846e08fa515da01d4c2459ae230a9c6702cc6ab9dc8533217d5d44e0
+ size 40
events.out.tfevents.1659174715.t1v-n-eedfb410-w-0.3333166.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ac1687a5731ee3d8fe81c3a3a9cad3798a3cfd2de48cf0e700947ec667e5183
+ size 40
events.out.tfevents.1659181854.t1v-n-eedfb410-w-0.3085831.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c744a976d0f5879c6383cdbfc1c543b24f09469a2279b2297c378547f26f08c5
+ size 40
events.out.tfevents.1659182962.t1v-n-eedfb410-w-0.2099342.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c58c7b4399f1690c0cf07b003a6f00eda584715b170f4b5ba7ae5694af533262
+ size 40
events.out.tfevents.1659184651.t1v-n-eedfb410-w-0.4852.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:711a53c830c473bac46091e9ffea0b461decbed903088bb927b1c4a9331a21f9
+ size 40
events.out.tfevents.1659185790.t1v-n-eedfb410-w-0.3212038.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3c9db04b6cba094a693bcea7d3e63f4af0128112bd11f0ca41d1dd0998cfc8a5
+ size 40
events.out.tfevents.1659190065.t1v-n-eedfb410-w-0.2276371.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c2abfac3540820545dabe683f72071edc7efb99f2d83d2bfd2a0ccd92f388ba2
+ size 40
events.out.tfevents.1659203868.t1v-n-eedfb410-w-0.1493173.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a908c7234714cc88dd9694ee3edfa2051d03c68e9bf8d6ca30df43ae3bdb58ef
+ size 40
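The events.out.tfevents.* files are TensorBoard event logs written on the TPU VM (t1v-n-eedfb410-w-0). At 40 bytes each they are committed as Git LFS pointers, so they need `git lfs pull` before they can be read. A hedged sketch of inspecting them locally with TensorBoard's event reader (the scalar tag name is an assumption that depends on the training script):

```python
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Point the accumulator at the checkout directory containing the tfevents files.
ea = EventAccumulator(".")
ea.Reload()

print(ea.Tags()["scalars"])         # available scalar tags
# print(ea.Scalars("train/loss"))   # hypothetical tag; adjust to whatever is listed above
```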
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c0a8533250d85a6e1bccebd2a781b061e35d55e0a42a09491a2f29826c05a05
+ size 3850218852
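flax_model.msgpack (~3.85 GB) holds the serialized Flax parameter tree. A minimal sketch of reading it directly with Flax's serialization utilities, assuming the file has been pulled from LFS; loading through `from_pretrained` as shown earlier is the more typical route:

```python
from flax import serialization

# Read the raw msgpack bytes (requires `git lfs pull` first).
with open("flax_model.msgpack", "rb") as f:
    params = serialization.msgpack_restore(f.read())

# Top-level parameter groups, e.g. the wav2vec2 encoder and the CTC head.
print(list(params.keys()))
```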
models/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from models.configuration_bart import BartConfig
+ from models.configuration_wav2vec2 import Wav2Vec2Config
+ from models.configuration_speech_encoder_decoder import SpeechEncoderDecoderConfig
+ from models.modeling_flax_wav2vec2 import FlaxWav2Vec2Model, FlaxWav2Vec2Module, FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForCTCModule
+ from models.modeling_flax_bart import FlaxBartForCausalLM, FlaxBartForCausalLMModule
+ from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
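The package re-exports the custom configs and Flax model classes, so downstream code such as run_flax_speech_recognition_ctc.py can import them from `models` directly. A hedged usage sketch, run from the repository root (the config values are illustrative; the committed checkpoint itself would instead be loaded with `from_pretrained(".")`):

```python
from models import FlaxWav2Vec2ForCTC, Wav2Vec2Config

# Build an untrained CTC model from this repository's custom config class,
# which also carries the fork-specific `use_scan` / `fuse_matmuls` flags.
config = Wav2Vec2Config(vocab_size=39)
model = FlaxWav2Vec2ForCTC(config)
print(config.use_scan, config.fuse_matmuls)
```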
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (724 Bytes)
models/__pycache__/configuration_bart.cpython-38.pyc ADDED
Binary file (7.02 kB)
models/__pycache__/configuration_speech_encoder_decoder.cpython-38.pyc ADDED
Binary file (4.6 kB)
models/__pycache__/configuration_wav2vec2.cpython-38.pyc ADDED
Binary file (16.8 kB)
models/__pycache__/modeling_flax_bart.cpython-38.pyc ADDED
Binary file (21.1 kB)
models/__pycache__/modeling_flax_speech_encoder_decoder.cpython-38.pyc ADDED
Binary file (39.4 kB)
models/__pycache__/modeling_flax_wav2vec2.cpython-38.pyc ADDED
Binary file (30.6 kB)
models/configuration_bart.py ADDED
@@ -0,0 +1,183 @@
+ # coding=utf-8
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ BART model configuration"""
+ import warnings
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+ BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json",
+ # See all BART models at https://huggingface.co/models?filter=bart
+ }
+
+
+ class BartConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the BART
+ [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 50265):
+ Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`].
+ d_model (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ encoder_layers (`int`, *optional*, defaults to 12):
+ Number of encoder layers.
+ decoder_layers (`int`, *optional*, defaults to 12):
+ Number of decoder layers.
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ activation_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for activations inside the fully connected layer.
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for classifier.
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ encoder_layerdrop: (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+ for more details.
+ decoder_layerdrop: (`float`, *optional*, defaults to 0.0):
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+ for more details.
+ scale_embedding (`bool`, *optional*, defaults to `False`):
+ Scale embeddings by dividing by sqrt(d_model).
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ num_labels: (`int`, *optional*, defaults to 3):
+ The number of labels to use in [`BartForSequenceClassification`].
+ forced_eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+ `eos_token_id`.
+ use_scan (`bool`, *optional*, defaults to `False`):
+ Whether or not to use nn.scan in the Flax Bart attention layers.
+
+ Example:
+
+ ```python
+ >>> from transformers import BartModel, BartConfig
+
+ >>> # Initializing a BART facebook/bart-large style configuration
+ >>> configuration = BartConfig()
+
+ >>> # Initializing a model from the facebook/bart-large style configuration
+ >>> model = BartModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "bart"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=50265,
+ max_position_embeddings=1024,
+ encoder_layers=12,
+ encoder_ffn_dim=4096,
+ encoder_attention_heads=16,
+ decoder_layers=12,
+ decoder_ffn_dim=4096,
+ decoder_attention_heads=16,
+ encoder_layerdrop=0.0,
+ decoder_layerdrop=0.0,
+ activation_function="gelu",
+ d_model=1024,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ classifier_dropout=0.0,
+ scale_embedding=False,
+ use_cache=True,
+ use_scan=False,
+ fuse_matmuls=False,
+ num_labels=3,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ is_encoder_decoder=True,
+ decoder_start_token_id=2,
+ forced_eos_token_id=2,
+ **kwargs
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.decoder_layers = decoder_layers
+ self.decoder_attention_heads = decoder_attention_heads
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.decoder_layerdrop = decoder_layerdrop
+ self.classifier_dropout = classifier_dropout
+ self.use_cache = use_cache
+ self.use_scan = use_scan
+ self.fuse_matmuls = fuse_matmuls
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+
+ super().__init__(
+ num_labels=num_labels,
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ decoder_start_token_id=decoder_start_token_id,
+ forced_eos_token_id=forced_eos_token_id,
+ **kwargs,
+ )
+
+ # ensure backward compatibility for BART CNN models
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
+ self.forced_bos_token_id = self.bos_token_id
+ warnings.warn(
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
+ "The config can simply be saved and uploaded again to be fixed."
+ )
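Relative to the upstream transformers BartConfig, this fork adds two Flax-specific switches, use_scan and fuse_matmuls, which are stored on the config and read by the modeling code further down in this commit. A hedged sketch of constructing a decoder config with them enabled (the values are illustrative; the checkpoint above was saved with both set to false):

```python
from models import BartConfig

# Illustrative settings; the committed config.json uses the defaults.
config = BartConfig(d_model=1024, decoder_layers=12, use_scan=True, fuse_matmuls=True)
print(config.use_scan, config.fuse_matmuls, config.num_hidden_layers)
```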
models/configuration_speech_encoder_decoder.py ADDED
@@ -0,0 +1,121 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import copy
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+ from models.configuration_wav2vec2 import Wav2Vec2Config
+ from models.configuration_bart import BartConfig
+ from transformers import AutoConfig
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class SpeechEncoderDecoderConfig(PretrainedConfig):
+ r"""
+ [`SpeechEncoderDecoderConfig`] is the configuration class to store the configuration of a
+ [`SpeechEncoderDecoderModel`]. It is used to instantiate an Encoder Decoder model according to the specified
+ arguments, defining the encoder and decoder configs.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ kwargs (*optional*):
+ Dictionary of keyword arguments. Notably:
+
+ - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+ the encoder config.
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+ the decoder config.
+
+ Examples:
+
+ ```python
+ >>> from transformers import BertConfig, Wav2Vec2Config, SpeechEncoderDecoderConfig, SpeechEncoderDecoderModel
+
+ >>> # Initializing a Wav2Vec2 & BERT style configuration
+ >>> config_encoder = Wav2Vec2Config()
+ >>> config_decoder = BertConfig()
+
+ >>> config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
+
+ >>> # Initializing a Wav2Vec2Bert model from a Wav2Vec2 & bert-base-uncased style configurations
+ >>> model = SpeechEncoderDecoderModel(config=config)
+
+ >>> # Accessing the model configuration
+ >>> config_encoder = model.config.encoder
+ >>> config_decoder = model.config.decoder
+ >>> # set decoder config to causal lm
+ >>> config_decoder.is_decoder = True
+ >>> config_decoder.add_cross_attention = True
+
+ >>> # Saving the model, including its configuration
+ >>> model.save_pretrained("my-model")
+
+ >>> # loading model and config from pretrained folder
+ >>> encoder_decoder_config = SpeechEncoderDecoderConfig.from_pretrained("my-model")
+ >>> model = SpeechEncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
+ ```"""
+ model_type = "speech-encoder-decoder"
+ is_composition = True
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if "encoder" not in kwargs or "decoder" not in kwargs:
+ raise ValueError(
+ f"A configuration of type {self.model_type} cannot be instantiated because not both `encoder` and `decoder` sub-configurations are passed, but only {kwargs}"
+ )
+
+ encoder_config = kwargs.pop("encoder")
+ decoder_config = kwargs.pop("decoder")
+
+ # TODO: Load configs from AutoConfig (as done in Transformers 🤗)
+ self.encoder = Wav2Vec2Config(**encoder_config)
+ self.decoder = BartConfig(**decoder_config)
+ self.is_encoder_decoder = True
+
+ @classmethod
+ def from_encoder_decoder_configs(
+ cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
+ ) -> PretrainedConfig:
+ r"""
+ Instantiate a [`SpeechEncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model
+ configuration and decoder model configuration.
+
+ Returns:
+ [`SpeechEncoderDecoderConfig`]: An instance of a configuration object
+ """
+ logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default *to_dict()* from *PretrainedConfig*.
+
+ Returns:
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["encoder"] = self.encoder.to_dict()
+ output["decoder"] = self.decoder.to_dict()
+ output["model_type"] = self.__class__.model_type
+ return output
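Because the encoder and decoder sub-configs are hard-wired to the local Wav2Vec2Config and BartConfig (rather than resolved through AutoConfig, as the TODO in __init__ notes), a composite config for this project would be built from those two classes. A minimal sketch with illustrative sizes:

```python
from models import BartConfig, SpeechEncoderDecoderConfig, Wav2Vec2Config

encoder_config = Wav2Vec2Config(hidden_size=1280, num_hidden_layers=48)
decoder_config = BartConfig(d_model=1024, decoder_layers=12)

# Marks the decoder as a causal LM with cross-attention, then nests both configs.
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
print(config.decoder.is_decoder, config.decoder.add_cross_attention)
```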
models/configuration_wav2vec2.py ADDED
@@ -0,0 +1,344 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Wav2Vec2 model configuration"""
16
+
17
+ import functools
18
+ import operator
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
28
+ # See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
29
+ }
30
+
31
+
32
+ class Wav2Vec2Config(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate an
35
+ Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
36
+ with the defaults will yield a similar configuration to that of the Wav2Vec2
37
+ [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+
43
+ Args:
44
+ vocab_size (`int`, *optional*, defaults to 32):
45
+ Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
46
+ the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`]. Vocabulary size of the
47
+ model. Defines the different tokens that can be represented by the *inputs_ids* passed to the forward
48
+ method of [`Wav2Vec2Model`].
49
+ hidden_size (`int`, *optional*, defaults to 768):
50
+ Dimensionality of the encoder layers and the pooler layer.
51
+ num_hidden_layers (`int`, *optional*, defaults to 12):
52
+ Number of hidden layers in the Transformer encoder.
53
+ num_attention_heads (`int`, *optional*, defaults to 12):
54
+ Number of attention heads for each attention layer in the Transformer encoder.
55
+ intermediate_size (`int`, *optional*, defaults to 3072):
56
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
57
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
60
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_dropout (`float`, *optional*, defaults to 0.1):
63
+ The dropout ratio for the attention probabilities.
64
+ final_dropout (`float`, *optional*, defaults to 0.1):
65
+ The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
66
+ initializer_range (`float`, *optional*, defaults to 0.02):
67
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
69
+ The epsilon used by the layer normalization layers.
70
+ feat_extract_norm (`str`, *optional*, defaults to `"group"`):
71
+ The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
72
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
73
+ convolutional layers.
74
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
75
+ The dropout probability for output of the feature encoder.
76
+ feat_extract_activation (`str, `optional`, defaults to `"gelu"`):
77
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
78
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
79
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
80
+ The dropout probabilitiy for quantized feature encoder states.
81
+ conv_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
82
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
83
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
84
+ conv_stride (`Tuple[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
85
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
86
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
87
+ conv_kernel (`Tuple[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
88
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
89
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
90
+ *conv_dim*.
91
+ conv_bias (`bool`, *optional*, defaults to `False`):
92
+ Whether the 1D convolutional layers have a bias.
93
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
94
+ Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
95
+ embeddings layer.
96
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
97
+ Number of groups of 1D convolutional positional embeddings layer.
98
+ do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
99
+ Whether to apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
100
+ True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
101
+ False` corresponds to applying layer norm after the attention layer.
102
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
103
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
104
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
105
+ Recognition](https://arxiv.org/abs/1904.08779).
106
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
107
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
108
+ procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
109
+ reasoning from the propability of each feature vector to be chosen as the start of the vector span to be
110
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
111
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
112
+ mask_time_length (`int`, *optional*, defaults to 10):
113
+ Length of vector span along the time axis.
114
+ mask_time_min_masks (`int`, *optional*, defaults to 2),:
115
+ The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step,
116
+ irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
117
+ mask_time_min_masks''
118
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
119
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
120
+ masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over
121
+ the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector
122
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
123
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
124
+ True`.
125
+ mask_feature_length (`int`, *optional*, defaults to 10):
126
+ Length of vector span along the feature axis.
127
+ mask_feature_min_masks (`int`, *optional*, defaults to 0),:
128
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
129
+ step, irrespectively of `mask_feature_prob`. Only relevant if
130
+ ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
131
+ num_codevectors_per_group (`int`, *optional*, defaults to 320):
132
+ Number of entries in each quantization codebook (group).
133
+ num_codevector_groups (`int`, *optional*, defaults to 2):
134
+ Number of codevector groups for product codevector quantization.
135
+ contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
136
+ The temperature *kappa* in the contrastive loss.
137
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
138
+ The dropout probabilitiy for the output of the feature encoder that's used by the quantizer.
139
+ num_negatives (`int`, *optional*, defaults to 100):
140
+ Number of negative samples for the contrastive loss.
141
+ codevector_dim (`int`, *optional*, defaults to 256):
142
+ Dimensionality of the quantized feature vectors.
143
+ proj_codevector_dim (`int`, *optional*, defaults to 256):
144
+ Dimensionality of the final projection of both the quantized and the transformer features.
145
+ diversity_loss_weight (`int`, *optional*, defaults to 0.1):
146
+ The weight of the codebook diversity loss component.
147
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
148
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
149
+ instance of [`Wav2Vec2ForCTC`].
150
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
151
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
152
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
153
+ of [`Wav2Vec2ForCTC`].
154
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
155
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
156
+ instance of [`Wav2Vec2ForSequenceClassification`].
157
+ classifier_proj_size (`int`, *optional*, defaults to 256):
158
+ Dimensionality of the projection before token mean-pooling for classification.
159
+ tdnn_dim (`Tuple[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
160
+ A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
161
+ module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
162
+ tdnn_kernel (`Tuple[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
163
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
164
+ *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
165
+ tdnn_dilation (`Tuple[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
166
+ A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
167
+ *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
168
+ xvector_output_dim (`int`, *optional*, defaults to 512):
169
+ Dimensionality of the *XVector* embedding vectors.
170
+ add_adapter (`bool`, *optional*, defaults to `False`):
171
+ Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
172
+ warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
173
+ adapter_kernel_size (`int`, *optional*, defaults to 3):
174
+ Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
175
+ adapter_stride (`int`, *optional*, defaults to 2):
176
+ Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
177
+ num_adapter_layers (`int`, *optional*, defaults to 3):
178
+ Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
179
+ True`.
180
+ output_hidden_size (`int`, *optional*):
181
+ Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
182
+ if `add_adapter is True`.
183
+ use_scan (`bool`, *optional*, defaults to `False`):
184
+ Whether or not to use nn.scan in the Flax Wav2Vec2 transformer layers.
185
+
186
+ Example:
187
+
188
+ ```python
189
+ >>> from transformers import Wav2Vec2Model, Wav2Vec2Config
190
+
191
+ >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
192
+ >>> configuration = Wav2Vec2Config()
193
+
194
+ >>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration
195
+ >>> model = Wav2Vec2Model(configuration)
196
+
197
+ >>> # Accessing the model configuration
198
+ >>> configuration = model.config
199
+ ```"""
200
+ model_type = "wav2vec2"
201
+
202
+ def __init__(
203
+ self,
204
+ vocab_size=32,
205
+ hidden_size=768,
206
+ num_hidden_layers=12,
207
+ num_attention_heads=12,
208
+ intermediate_size=3072,
209
+ hidden_act="gelu",
210
+ hidden_dropout=0.1,
211
+ activation_dropout=0.1,
212
+ attention_dropout=0.1,
213
+ feat_proj_dropout=0.0,
214
+ feat_quantizer_dropout=0.0,
215
+ final_dropout=0.1,
216
+ layerdrop=0.1,
217
+ initializer_range=0.02,
218
+ layer_norm_eps=1e-5,
219
+ feat_extract_norm="group",
220
+ feat_extract_activation="gelu",
221
+ conv_dim=(512, 512, 512, 512, 512, 512, 512),
222
+ conv_stride=(5, 2, 2, 2, 2, 2, 2),
223
+ conv_kernel=(10, 3, 3, 3, 3, 2, 2),
224
+ conv_bias=False,
225
+ num_conv_pos_embeddings=128,
226
+ num_conv_pos_embedding_groups=16,
227
+ do_stable_layer_norm=False,
228
+ apply_spec_augment=True,
229
+ mask_time_prob=0.05,
230
+ mask_time_length=10,
231
+ mask_time_min_masks=2,
232
+ mask_feature_prob=0.0,
233
+ mask_feature_length=10,
234
+ mask_feature_min_masks=0,
235
+ num_codevectors_per_group=320,
236
+ num_codevector_groups=2,
237
+ contrastive_logits_temperature=0.1,
238
+ num_negatives=100,
239
+ codevector_dim=256,
240
+ proj_codevector_dim=256,
241
+ diversity_loss_weight=0.1,
242
+ ctc_loss_reduction="sum",
243
+ ctc_zero_infinity=False,
244
+ use_weighted_layer_sum=False,
245
+ classifier_proj_size=256,
246
+ tdnn_dim=(512, 512, 512, 512, 1500),
247
+ tdnn_kernel=(5, 3, 3, 1, 1),
248
+ tdnn_dilation=(1, 2, 3, 1, 1),
249
+ xvector_output_dim=512,
250
+ pad_token_id=0,
251
+ bos_token_id=1,
252
+ eos_token_id=2,
253
+ add_adapter=False,
254
+ adapter_kernel_size=3,
255
+ adapter_stride=2,
256
+ num_adapter_layers=3,
257
+ output_hidden_size=None,
258
+ use_scan=False,
259
+ fuse_matmuls=False,
260
+ **kwargs
261
+ ):
262
+ super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
263
+ self.hidden_size = hidden_size
264
+ self.feat_extract_norm = feat_extract_norm
265
+ self.feat_extract_activation = feat_extract_activation
266
+ self.conv_dim = list(conv_dim)
267
+ self.conv_stride = list(conv_stride)
268
+ self.conv_kernel = list(conv_kernel)
269
+ self.conv_bias = conv_bias
270
+ self.num_conv_pos_embeddings = num_conv_pos_embeddings
271
+ self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
272
+ self.num_feat_extract_layers = len(self.conv_dim)
273
+ self.num_hidden_layers = num_hidden_layers
274
+ self.intermediate_size = intermediate_size
275
+ self.hidden_act = hidden_act
276
+ self.num_attention_heads = num_attention_heads
277
+ self.hidden_dropout = hidden_dropout
278
+ self.attention_dropout = attention_dropout
279
+ self.activation_dropout = activation_dropout
280
+ self.feat_proj_dropout = feat_proj_dropout
281
+ self.final_dropout = final_dropout
282
+ self.layerdrop = layerdrop
283
+ self.layer_norm_eps = layer_norm_eps
284
+ self.initializer_range = initializer_range
285
+ self.vocab_size = vocab_size
286
+ self.do_stable_layer_norm = do_stable_layer_norm
287
+ self.use_weighted_layer_sum = use_weighted_layer_sum
288
+ self.use_scan = use_scan
289
+ self.fuse_matmuls = fuse_matmuls
290
+
291
+ if (
292
+ (len(self.conv_stride) != self.num_feat_extract_layers)
293
+ or (len(self.conv_kernel) != self.num_feat_extract_layers)
294
+ or (len(self.conv_dim) != self.num_feat_extract_layers)
295
+ ):
296
+ raise ValueError(
297
+ "Configuration for convolutional layers is incorrect. "
298
+ "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
299
+ f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
300
+ f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
301
+ )
302
+
303
+ # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
304
+ self.apply_spec_augment = apply_spec_augment
305
+ self.mask_time_prob = mask_time_prob
306
+ self.mask_time_length = mask_time_length
307
+ self.mask_time_min_masks = mask_time_min_masks
308
+ self.mask_feature_prob = mask_feature_prob
309
+ self.mask_feature_length = mask_feature_length
310
+ self.mask_feature_min_masks = mask_feature_min_masks
311
+
312
+ # parameters for pretraining with codevector quantized representations
313
+ self.num_codevectors_per_group = num_codevectors_per_group
314
+ self.num_codevector_groups = num_codevector_groups
315
+ self.contrastive_logits_temperature = contrastive_logits_temperature
316
+ self.feat_quantizer_dropout = feat_quantizer_dropout
317
+ self.num_negatives = num_negatives
318
+ self.codevector_dim = codevector_dim
319
+ self.proj_codevector_dim = proj_codevector_dim
320
+ self.diversity_loss_weight = diversity_loss_weight
321
+
322
+ # ctc loss
323
+ self.ctc_loss_reduction = ctc_loss_reduction
324
+ self.ctc_zero_infinity = ctc_zero_infinity
325
+
326
+ # adapter
327
+ self.add_adapter = add_adapter
328
+ self.adapter_kernel_size = adapter_kernel_size
329
+ self.adapter_stride = adapter_stride
330
+ self.num_adapter_layers = num_adapter_layers
331
+ self.output_hidden_size = output_hidden_size or hidden_size
332
+
333
+ # SequenceClassification-specific parameter. Feel free to ignore for other classes.
334
+ self.classifier_proj_size = classifier_proj_size
335
+
336
+ # XVector-specific parameters. Feel free to ignore for other classes.
337
+ self.tdnn_dim = list(tdnn_dim)
338
+ self.tdnn_kernel = list(tdnn_kernel)
339
+ self.tdnn_dilation = list(tdnn_dilation)
340
+ self.xvector_output_dim = xvector_output_dim
341
+
342
+ @property
343
+ def inputs_to_logits_ratio(self):
344
+ return functools.reduce(operator.mul, self.conv_stride, 1)
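The inputs_to_logits_ratio property at the end of the file is the product of the feature-extractor strides, i.e. the audio-sample-to-frame downsampling factor. For the conv_stride used in this checkpoint, (5, 2, 2, 2, 2, 2, 2), that is 5 * 2**6 = 320, so 16 kHz audio yields roughly 50 CTC frames per second. A small worked check:

```python
import functools
import operator

conv_stride = (5, 2, 2, 2, 2, 2, 2)  # value from config.json above
ratio = functools.reduce(operator.mul, conv_stride, 1)
print(ratio)           # 320 samples per output frame
print(16000 / ratio)   # 50.0 frames per second at 16 kHz
```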
models/modeling_flax_bart.py ADDED
@@ -0,0 +1,816 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax Bart model."""
16
+
17
+ import math
18
+ import random
19
+ from functools import partial
20
+ from typing import Optional, Tuple
21
+
22
+ import numpy as np
23
+
24
+ import flax.linen as nn
25
+ import jax
26
+ import jax.numpy as jnp
27
+ from flax.core.frozen_dict import FrozenDict, unfreeze
28
+ from flax.linen import combine_masks, make_causal_mask
29
+ from flax.linen import partitioning as nn_partitioning
30
+ from flax.linen.attention import dot_product_attention_weights
31
+ from jax import lax
32
+ from jax.random import PRNGKey
33
+
34
+ from transformers.modeling_flax_outputs import (
35
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
36
+ FlaxCausalLMOutputWithCrossAttentions,
37
+ )
38
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
39
+
40
+ from models import BartConfig
41
+
42
+
43
+ scan_with_axes = nn_partitioning.scan_with_axes
44
+ remat = nn_partitioning.remat
45
+
46
+
47
+ def shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
48
+ """
49
+ Shift input ids one token to the right.
50
+ """
51
+ shifted_input_ids = np.zeros_like(input_ids)
52
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
53
+ shifted_input_ids[:, 0] = decoder_start_token_id
54
+
55
+ shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
56
+ return shifted_input_ids
57
+
58
+
59
+ class FlaxBartAttention(nn.Module):
60
+ config: BartConfig
61
+ embed_dim: int
62
+ num_heads: int
63
+ dropout: float = 0.0
64
+ causal: bool = False
65
+ bias: bool = True
66
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
67
+
68
+ def setup(self) -> None:
69
+ self.head_dim = self.embed_dim // self.num_heads
70
+ if self.head_dim * self.num_heads != self.embed_dim:
71
+ raise ValueError(
72
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
73
+ f" and `num_heads`: {self.num_heads})."
74
+ )
75
+
76
+ dense = partial(
77
+ nn.Dense,
78
+ self.embed_dim,
79
+ use_bias=self.bias,
80
+ dtype=self.dtype,
81
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
82
+ )
83
+
84
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
85
+
86
+ self.fused_proj = nn.Dense(
87
+ self.embed_dim * 3,
88
+ use_bias=self.bias,
89
+ dtype=self.dtype,
90
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
91
+ )
92
+
93
+ self.fused_key_value = nn.Dense(
94
+ self.embed_dim * 2,
95
+ use_bias=self.bias,
96
+ dtype=self.dtype,
97
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
98
+ )
99
+
100
+ self.out_proj = dense()
101
+
102
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
103
+
104
+ if self.causal:
105
+ self.causal_mask = make_causal_mask(
106
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
107
+ )
108
+
109
+ def _split_heads(self, hidden_states):
110
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
111
+
112
+ def _merge_heads(self, hidden_states):
113
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
114
+
115
+ @nn.compact
116
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
117
+ """
118
+ This function takes projected key, value states from a single input token and concatenates the states to cached
119
+ states from previous steps. This function is slighly adapted from the official Flax repository:
120
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
121
+ """
122
+ # detect if we're initializing by absence of existing cache data.
123
+ is_initialized = self.has_variable("cache", "cached_key")
124
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
125
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
126
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
127
+
128
+ if is_initialized:
129
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
130
+ # update key, value caches with our new 1d spatial slices
131
+ cur_index = cache_index.value
132
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
133
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
134
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
135
+ cached_key.value = key
136
+ cached_value.value = value
137
+ num_updated_cache_vectors = query.shape[1]
138
+ cache_index.value = cache_index.value + num_updated_cache_vectors
139
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
140
+ pad_mask = jnp.broadcast_to(
141
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
142
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
143
+ )
144
+ attention_mask = combine_masks(pad_mask, attention_mask)
145
+ return key, value, attention_mask
146
+
147
+ def __call__(
148
+ self,
149
+ hidden_states: jnp.ndarray,
150
+ key_value_states: Optional[jnp.ndarray] = None,
151
+ attention_mask: Optional[jnp.ndarray] = None,
152
+ init_cache: bool = False,
153
+ deterministic: bool = True,
154
+ ) -> Tuple[jnp.ndarray]:
155
+ """Input shape: Batch x Time x Channel"""
156
+
157
+ # if key_value_states are provided this layer is used as a cross-attention layer
158
+ # for the decoder
159
+ is_cross_attention = key_value_states is not None
160
+ batch_size = hidden_states.shape[0]
161
+
162
+ if self.config.fuse_matmuls:
163
+ # get key, value proj
164
+ if is_cross_attention:
165
+ # get query proj
166
+ query_states = self.q_proj(hidden_states)
167
+ # cross_attentions
168
+ attention_states = self.fused_key_value(key_value_states)
169
+ key_states, value_states = jnp.split(attention_states, 2, axis=-1)
170
+ else:
171
+ attention_states = self.fused_proj(hidden_states)
172
+ query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
173
+
174
+ else:
175
+ # get query proj
176
+ query_states = self.q_proj(hidden_states)
177
+ # get key, value proj
178
+ if is_cross_attention:
179
+ # cross_attentions
180
+ key_states = self.k_proj(key_value_states)
181
+ value_states = self.v_proj(key_value_states)
182
+ else:
183
+ # self_attention
184
+ key_states = self.k_proj(hidden_states)
185
+ value_states = self.v_proj(hidden_states)
186
+
187
+ query_states = self._split_heads(query_states)
188
+ key_states = self._split_heads(key_states)
189
+ value_states = self._split_heads(value_states)
190
+
191
+ # handle cache prepare causal attention mask
192
+ if self.causal:
193
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
194
+ if self.has_variable("cache", "cached_key"):
195
+ mask_shift = self.variables["cache"]["cache_index"]
196
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
197
+ causal_mask = lax.dynamic_slice(
198
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
199
+ )
200
+ else:
201
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
202
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
203
+
204
+ # combine masks if needed
205
+ if attention_mask is not None and self.causal:
206
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
207
+ attention_mask = combine_masks(attention_mask, causal_mask)
208
+ elif self.causal:
209
+ attention_mask = causal_mask
210
+ elif attention_mask is not None:
211
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
212
+
213
+ # During fast autoregressive decoding, we feed one position at a time,
214
+ # and cache the keys and values step by step.
215
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
216
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
217
+ key_states, value_states, query_states, attention_mask
218
+ )
219
+
220
+ # Convert the boolean attention mask to an attention bias.
221
+ if attention_mask is not None:
222
+ # attention mask in the form of attention bias
223
+ attention_bias = lax.select(
224
+ attention_mask > 0,
225
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
226
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
227
+ )
228
+ else:
229
+ attention_bias = None
230
+
231
+ dropout_rng = None
232
+ if not deterministic and self.dropout > 0.0:
233
+ dropout_rng = self.make_rng("dropout")
234
+
235
+ attn_weights = dot_product_attention_weights(
236
+ query_states,
237
+ key_states,
238
+ bias=attention_bias,
239
+ dropout_rng=dropout_rng,
240
+ dropout_rate=self.dropout,
241
+ broadcast_dropout=True,
242
+ deterministic=deterministic,
243
+ dtype=self.dtype,
244
+ precision=None,
245
+ )
246
+
247
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
248
+ attn_output = self._merge_heads(attn_output)
249
+ attn_output = self.out_proj(attn_output)
250
+
251
+ return attn_output, attn_weights
252
+
253
+
254
+ class FlaxBartDecoderLayer(nn.Module):
255
+ config: BartConfig
256
+ dtype: jnp.dtype = jnp.float32
257
+
258
+ def setup(self) -> None:
259
+ self.embed_dim = self.config.d_model
260
+ self.self_attn = FlaxBartAttention(
261
+ config=self.config,
262
+ embed_dim=self.embed_dim,
263
+ num_heads=self.config.decoder_attention_heads,
264
+ dropout=self.config.attention_dropout,
265
+ causal=True,
266
+ dtype=self.dtype,
267
+ )
268
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
269
+ self.activation_fn = ACT2FN[self.config.activation_function]
270
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
271
+
272
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
273
+ self.encoder_attn = FlaxBartAttention(
274
+ config=self.config,
275
+ embed_dim=self.embed_dim,
276
+ num_heads=self.config.decoder_attention_heads,
277
+ dropout=self.config.attention_dropout,
278
+ dtype=self.dtype,
279
+ )
280
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
281
+ self.fc1 = nn.Dense(
282
+             self.config.decoder_ffn_dim,
283
+ dtype=self.dtype,
284
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
285
+ )
286
+ self.fc2 = nn.Dense(
287
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
288
+ )
289
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
290
+
291
+ def __call__(
292
+ self,
293
+ hidden_states: jnp.ndarray,
294
+ attention_mask: jnp.ndarray,
295
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
296
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
297
+ init_cache: bool = False,
298
+ output_attentions: bool = True,
299
+ deterministic: bool = True,
300
+ ) -> Tuple[jnp.ndarray]:
301
+
302
+ if self.config.use_scan:
303
+ hidden_states = hidden_states[0]
304
+
305
+ residual = hidden_states
306
+
307
+ # Self Attention
308
+ hidden_states, self_attn_weights = self.self_attn(
309
+ hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
310
+ )
311
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
312
+ hidden_states = residual + hidden_states
313
+ hidden_states = self.self_attn_layer_norm(hidden_states)
314
+
315
+ # Cross-Attention Block
316
+ cross_attn_weights = None
317
+ if encoder_hidden_states is not None:
318
+ residual = hidden_states
319
+
320
+ hidden_states, cross_attn_weights = self.encoder_attn(
321
+ hidden_states=hidden_states,
322
+ key_value_states=encoder_hidden_states,
323
+ attention_mask=encoder_attention_mask,
324
+ )
325
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
326
+ hidden_states = residual + hidden_states
327
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
328
+
329
+ # Fully Connected
330
+ residual = hidden_states
331
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
332
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
333
+ hidden_states = self.fc2(hidden_states)
334
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
335
+ hidden_states = residual + hidden_states
336
+ hidden_states = self.final_layer_norm(hidden_states)
337
+
338
+ outputs = (hidden_states,)
339
+
340
+ if output_attentions:
341
+ outputs += (self_attn_weights, cross_attn_weights)
342
+
343
+ if self.config.use_scan:
344
+ outputs = (outputs, None)
345
+
346
+ return outputs
347
+
348
+
349
+ class FlaxBartDecoderLayerCollection(nn.Module):
350
+ config: BartConfig
351
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
352
+
353
+ @nn.compact
354
+ def __call__(
355
+ self,
356
+ hidden_states,
357
+ attention_mask,
358
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
359
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
360
+ deterministic: bool = True,
361
+ init_cache: bool = False,
362
+ output_attentions: bool = False,
363
+ output_hidden_states: bool = False,
364
+ return_dict: bool = True,
365
+ ):
366
+ # decoder layers
367
+ all_hidden_states = () if output_hidden_states else None
368
+ all_self_attns = () if output_attentions else None
369
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
370
+
371
+ num_decoder_layers = self.config.decoder_layers
372
+ BlockDecoderLayer = (
373
+ remat(
374
+ FlaxBartDecoderLayer,
375
+ static_argnums=(4, 5, 6),
376
+ prevent_cse=not self.config.use_scan,
377
+ )
378
+ if self.config.gradient_checkpointing
379
+ else FlaxBartDecoderLayer
380
+ )
381
+
382
+ if self.config.use_scan:
383
+ # since all decoder layers are the same, we use nn.scan directly
384
+ assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
385
+ assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
386
+ hidden_states = (hidden_states,)
387
+
388
+ # TODO: add layerdrop in checkpointed scan (note: default value for layerdrop in config is zero)
389
+ hidden_states, _ = scan_with_axes(
390
+ BlockDecoderLayer,
391
+ variable_axes={"params": 0, "cache": 0},
392
+ split_rngs={"params": True, "dropout": True},
393
+ in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast),
394
+ length=num_decoder_layers,
395
+ )(self.config, dtype=self.dtype, name="FlaxBartDecoderLayers")(
396
+ hidden_states,
397
+ attention_mask,
398
+ encoder_hidden_states,
399
+ encoder_attention_mask,
400
+ init_cache,
401
+ output_attentions,
402
+ deterministic,
403
+ )
404
+ hidden_states = hidden_states[0]
405
+
406
+ else:
407
+ for layer in range(num_decoder_layers):
408
+ if output_hidden_states:
409
+ all_hidden_states += (hidden_states,)
410
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
411
+ dropout_probability = random.uniform(0, 1)
412
+ if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
413
+ layer_outputs = (None, None, None)
414
+ else:
415
+ layer_outputs = BlockDecoderLayer(self.config, dtype=self.dtype, name=str(layer),)(
416
+ hidden_states,
417
+ attention_mask,
418
+ encoder_hidden_states,
419
+ encoder_attention_mask,
420
+ init_cache,
421
+ output_attentions,
422
+ deterministic,
423
+ )
424
+
425
+ hidden_states = layer_outputs[0]
426
+ if output_attentions:
427
+ all_self_attns += (layer_outputs[1],)
428
+
429
+ if encoder_hidden_states is not None:
430
+ all_cross_attentions += (layer_outputs[2],)
431
+
432
+ # add hidden states from the last decoder layer
433
+ if output_hidden_states:
434
+ all_hidden_states += (hidden_states,)
435
+
436
+ outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
437
+
438
+ if not return_dict:
439
+ return tuple(v for v in outputs if v is not None)
440
+
441
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
442
+ last_hidden_state=hidden_states,
443
+ hidden_states=all_hidden_states,
444
+ attentions=all_self_attns,
445
+ cross_attentions=all_cross_attentions,
446
+ )
447
+
448
+
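+ # Why the two code paths above: with `use_scan`, all decoder layers share a single compiled layer body,
+ # which cuts XLA compilation time but cannot expose per-layer hidden states or attentions (hence the
+ # asserts). A rough sketch of what the `scan_with_axes` helper is assumed to wrap, using plain
+ # flax.linen.scan:
+ #
+ #     ScannedLayer = nn.scan(
+ #         FlaxBartDecoderLayer,
+ #         variable_axes={"params": 0, "cache": 0},
+ #         split_rngs={"params": True, "dropout": True},
+ #         in_axes=(nn.broadcast,) * 6,
+ #         length=config.decoder_layers,
+ #     )
+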
449
+ class FlaxBartDecoder(nn.Module):
450
+ config: BartConfig
451
+ embed_tokens: nn.Embed
452
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
453
+
454
+ def setup(self):
455
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
456
+
457
+ embed_dim = self.config.d_model
458
+ self.padding_idx = self.config.pad_token_id
459
+ self.max_target_positions = self.config.max_position_embeddings
460
+ self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
461
+
462
+         # Bart is set up so that, if padding_idx is specified, the embedding ids are offset by 2
464
+         # and num_embeddings is adjusted accordingly. Other models don't have this hack.
464
+ self.offset = 2
465
+ self.embed_positions = nn.Embed(
466
+ self.config.max_position_embeddings + self.offset,
467
+ embed_dim,
468
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
469
+ )
470
+
471
+ self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
472
+ self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
473
+
474
+ def __call__(
475
+ self,
476
+ input_ids,
477
+ attention_mask,
478
+ position_ids,
479
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
480
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
481
+ init_cache: bool = False,
482
+ output_attentions: bool = False,
483
+ output_hidden_states: bool = False,
484
+ return_dict: bool = True,
485
+ deterministic: bool = True,
486
+ ):
487
+ input_shape = input_ids.shape
488
+ input_ids = input_ids.reshape(-1, input_shape[-1])
489
+
490
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
491
+
492
+ # embed positions
493
+ positions = self.embed_positions(position_ids + self.offset)
494
+
495
+ hidden_states = inputs_embeds + positions
496
+ hidden_states = self.layernorm_embedding(hidden_states)
497
+
498
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
499
+
500
+ outputs = self.layers(
501
+ hidden_states,
502
+ attention_mask,
503
+ encoder_hidden_states,
504
+ encoder_attention_mask,
505
+ deterministic=deterministic,
506
+ init_cache=init_cache,
507
+ output_attentions=output_attentions,
508
+ output_hidden_states=output_hidden_states,
509
+ return_dict=return_dict,
510
+ )
511
+
512
+ if not return_dict:
513
+ return outputs
514
+
515
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
516
+ last_hidden_state=outputs.last_hidden_state,
517
+ hidden_states=outputs.hidden_states,
518
+ attentions=outputs.attentions,
519
+ cross_attentions=outputs.cross_attentions,
520
+ )
521
+
522
+
523
+ class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
524
+ config_class = BartConfig
525
+ base_model_prefix: str = "model"
526
+ module_class: nn.Module = None
527
+
528
+ def __init__(
529
+ self,
530
+ config: BartConfig,
531
+ input_shape: Tuple[int] = (1, 1),
532
+ seed: int = 0,
533
+ dtype: jnp.dtype = jnp.float32,
534
+ _do_init: bool = True,
535
+ **kwargs
536
+ ):
537
+ config.is_decoder = True
538
+ config.is_encoder_decoder = False
539
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
540
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
541
+
542
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
543
+ # init input tensors
544
+ input_ids = jnp.zeros(input_shape, dtype="i4")
545
+ attention_mask = jnp.ones_like(input_ids)
546
+
547
+ batch_size, sequence_length = input_ids.shape
548
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
549
+
550
+ params_rng, dropout_rng = jax.random.split(rng)
551
+ rngs = {"params": params_rng, "dropout": dropout_rng}
552
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
553
+ encoder_attention_mask = attention_mask
554
+ module_init_outputs = self.module.init(
555
+ rngs,
556
+ input_ids,
557
+ attention_mask,
558
+ position_ids,
559
+ encoder_hidden_states,
560
+ encoder_attention_mask,
561
+ return_dict=False,
562
+ )
563
+ return module_init_outputs["params"]
564
+
565
+ def init_cache(self, batch_size, max_length):
566
+ r"""
567
+ Args:
568
+ batch_size (`int`):
569
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
570
+ max_length (`int`):
571
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
572
+ cache.
573
+ """
574
+ # init input variables to retrieve cache
575
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
576
+ attention_mask = jnp.ones_like(input_ids, dtype="i4")
577
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
578
+
579
+ init_variables = self.module.init(
580
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
581
+ )
582
+ return unfreeze(init_variables["cache"])
583
+
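+     # A usage sketch of the cache returned above, e.g. with a FlaxBartForCausalLM instance defined later
+     # in this file (variable names are illustrative only):
+     #
+     #     cache = model.init_cache(batch_size=1, max_length=16)
+     #     outputs = model(
+     #         input_ids=jnp.array([[bos_token_id]], dtype="i4"),
+     #         position_ids=jnp.array([[0]], dtype="i4"),
+     #         past_key_values=cache,
+     #     )
+     #     cache = outputs["past_key_values"]  # updated cache, fed back in for the next decoding step
+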
584
+ def __call__(
585
+ self,
586
+ input_ids: jnp.ndarray,
587
+ attention_mask: Optional[jnp.ndarray] = None,
588
+ position_ids: Optional[jnp.ndarray] = None,
589
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
590
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
591
+ output_attentions: Optional[bool] = None,
592
+ output_hidden_states: Optional[bool] = None,
593
+ return_dict: Optional[bool] = None,
594
+ train: bool = False,
595
+ params: dict = None,
596
+ past_key_values: dict = None,
597
+ dropout_rng: PRNGKey = None,
598
+ ):
599
+ """
600
+ Args:
601
+ input_ids (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`):
602
+ Indices of decoder input sequence tokens in the vocabulary.
603
+
604
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
605
+ [`PreTrainedTokenizer.__call__`] for details.
606
+
607
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
608
+
609
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
610
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
611
+ for denoising pre-training following the paper.
612
+ attention_mask (`jnp.ndarray` of shape `(target_batch_size, target_sequence_length)`, *optional*):
613
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
614
+ be used by default.
615
+
616
+ If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the
617
+ paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
618
+ position_ids (`numpy.ndarray` of shape `(target_batch_size, sequence_length)`, *optional*):
619
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
620
+ range `[0, config.max_position_embeddings - 1]`.
621
+ encoder_hidden_states (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
622
+ A sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
623
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
624
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
625
+
626
+ - 1 for tokens that are **not masked**,
627
+ - 0 for tokens that are **masked**.
628
+
629
+ [What are attention masks?](../glossary#attention-mask)
630
+ output_attentions (`bool`, *optional*):
631
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
632
+ tensors for more detail.
633
+ output_hidden_states (`bool`, *optional*):
634
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
635
+ more detail.
636
+ return_dict (`bool`, *optional*):
637
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
638
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
639
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
640
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
641
+ """
642
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
643
+ output_hidden_states = (
644
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
645
+ )
646
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
647
+
648
+ if encoder_hidden_states is not None and encoder_attention_mask is None:
649
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
650
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
651
+
652
+ # prepare decoder inputs
653
+ if attention_mask is None:
654
+ attention_mask = jnp.ones_like(input_ids)
655
+ if position_ids is None:
656
+ batch_size, sequence_length = input_ids.shape
657
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
658
+
659
+ # Handle any PRNG if needed
660
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
661
+
662
+ inputs = {"params": params or self.params}
663
+
664
+         # If `past_key_values` are passed, the cache is already initialized, so a private flag `init_cache` has to be
665
+         # passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
666
+         # changed by the FlaxBartAttention module.
667
+ if past_key_values:
668
+ inputs["cache"] = past_key_values
669
+ mutable = ["cache"]
670
+ else:
671
+ mutable = False
672
+
673
+ outputs = self.module.apply(
674
+ inputs,
675
+ input_ids=jnp.array(input_ids, dtype="i4"),
676
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
677
+ position_ids=jnp.array(position_ids, dtype="i4"),
678
+ encoder_hidden_states=encoder_hidden_states,
679
+ encoder_attention_mask=encoder_attention_mask,
680
+ output_attentions=output_attentions,
681
+ output_hidden_states=output_hidden_states,
682
+ return_dict=return_dict,
683
+ deterministic=not train,
684
+ rngs=rngs,
685
+ mutable=mutable,
686
+ )
687
+
688
+ # add updated cache to model output
689
+ if past_key_values is not None and return_dict:
690
+ outputs, past_key_values = outputs
691
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
692
+ return outputs
693
+ elif past_key_values is not None and not return_dict:
694
+ outputs, past_key_values = outputs
695
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
696
+
697
+ return outputs
698
+
699
+
700
+ class FlaxBartDecoderWrapper(nn.Module):
701
+ """
702
+ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
703
+ used in combination with the [`EncoderDecoderModel`] framework.
704
+ """
705
+
706
+ config: BartConfig
707
+ dtype: jnp.dtype = jnp.float32
708
+
709
+ def setup(self):
710
+ embed_dim = self.config.d_model
711
+ embed_tokens = nn.Embed(
712
+ self.config.vocab_size,
713
+ embed_dim,
714
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
715
+ )
716
+ self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)
717
+
718
+ def __call__(self, *args, **kwargs):
719
+ return self.decoder(*args, **kwargs)
720
+
721
+
722
+ class FlaxBartForCausalLMModule(nn.Module):
723
+ """Bart Decoder Module with a language modeling head on top (linear layer with weights tied to the input embeddings)
724
+ e.g. for autoregressive tasks.
725
+ """
726
+
727
+ config: BartConfig
728
+ dtype: jnp.dtype = jnp.float32
729
+
730
+ def setup(self):
731
+ self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
732
+ self.lm_head = nn.Dense(
733
+ self.config.vocab_size,
734
+ use_bias=False,
735
+ dtype=self.dtype,
736
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
737
+ )
738
+
739
+ def __call__(
740
+ self,
741
+ input_ids,
742
+ attention_mask,
743
+ position_ids,
744
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
745
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
746
+ init_cache: bool = False,
747
+ output_attentions: bool = False,
748
+ output_hidden_states: bool = False,
749
+ return_dict: bool = True,
750
+ deterministic: bool = True,
751
+ ):
752
+
753
+ outputs = self.model(
754
+ input_ids,
755
+ attention_mask,
756
+ position_ids,
757
+ encoder_hidden_states,
758
+ encoder_attention_mask,
759
+ deterministic=deterministic,
760
+ init_cache=init_cache,
761
+ output_attentions=output_attentions,
762
+ output_hidden_states=output_hidden_states,
763
+ return_dict=return_dict,
764
+ )
765
+
766
+ hidden_states = outputs[0]
767
+
768
+ if self.config.tie_word_embeddings:
769
+ shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
770
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
771
+ else:
772
+ lm_logits = self.lm_head(hidden_states)
773
+
774
+ if not return_dict:
775
+ return (lm_logits,) + outputs[1:]
776
+
777
+ return FlaxCausalLMOutputWithCrossAttentions(
778
+ logits=lm_logits,
779
+ hidden_states=outputs.hidden_states,
780
+ attentions=outputs.attentions,
781
+ cross_attentions=outputs.cross_attentions,
782
+ )
783
+
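+ # Note on the weight tying above: when `tie_word_embeddings` is set, the LM head reuses the
+ # (vocab_size, d_model) token-embedding matrix, transposed, as its kernel, so the logits are computed
+ # as hidden_states @ embedding.T and no separate output projection is learned.
+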
784
+
785
+ class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
786
+ """Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings)
787
+ e.g. for autoregressive tasks.
788
+ """
789
+
790
+ module_class = FlaxBartForCausalLMModule
791
+
792
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
793
+ # initializing the cache
794
+ batch_size, seq_length = input_ids.shape
795
+
796
+ past_key_values = self.init_cache(batch_size, max_length)
797
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
798
+ # But since the decoder uses a causal mask, those positions are masked anyway.
799
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
800
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
801
+ if attention_mask is not None:
802
+ position_ids = attention_mask.cumsum(axis=-1) - 1
803
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
804
+ else:
805
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
806
+
807
+ return {
808
+ "past_key_values": past_key_values,
809
+ "attention_mask": extended_attention_mask,
810
+ "position_ids": position_ids,
811
+ }
812
+
813
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
814
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
815
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
816
+ return model_kwargs
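+ # The two hooks above are what FlaxGenerationMixin.generate relies on for cached decoding: the first call
+ # processes the whole prompt and fills the cache, later calls feed a single new token. A rough, greatly
+ # simplified greedy loop (illustrative only; the real implementation in generation_flax_utils uses
+ # lax.while_loop and additionally handles padding, EOS and scores):
+ #
+ #     model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length=max_length)
+ #     running_token = input_ids                                           # full prompt on the first step
+ #     sequences = input_ids
+ #     for _ in range(max_length - input_ids.shape[-1]):
+ #         outputs = model(running_token, **model_kwargs)
+ #         running_token = jnp.argmax(outputs.logits[:, -1:], axis=-1)     # (batch, 1) next token
+ #         sequences = jnp.concatenate([sequences, running_token], axis=-1)
+ #         model_kwargs = model.update_inputs_for_generation(outputs, model_kwargs)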
models/modeling_flax_speech_encoder_decoder.py ADDED
@@ -0,0 +1,1245 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Classes to support Flax Speech-Encoder-Decoder architectures"""
16
+
17
+ import os
18
+ from functools import partial
19
+ from typing import Optional, Tuple, Union, Dict
20
+
21
+ import flax
22
+ import flax.linen as nn
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from flax.core.frozen_dict import FrozenDict, unfreeze
26
+ from jax import lax
27
+ from jax.random import PRNGKey
28
+ import numpy as np
29
+
30
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
31
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
32
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput
33
+ from transformers.generation_flax_utils import FlaxLogitsProcessorList
34
+ from models import (
35
+ FlaxWav2Vec2Model,
36
+ FlaxWav2Vec2Module,
37
+ FlaxBartForCausalLM,
38
+ FlaxBartForCausalLMModule,
39
+ BartConfig,
40
+ Wav2Vec2Config,
41
+ SpeechEncoderDecoderConfig,
42
+ )
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ _CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig"
47
+
48
+ SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
49
+ This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech
50
+ autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is
51
+ loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via
52
+ [`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder
53
+     and should be fine-tuned on a downstream generative task, such as automatic speech recognition or speech translation.
54
+
55
+ The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
56
+ tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
57
+     Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, and Aliaksei
59
+     Severyn.
59
+
60
+ Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech
61
+ Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech
62
+ translation yields a significant performance improvement.
63
+
64
+     After such a Speech-Encoder-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
65
+ models (see the examples for more information).
66
+
67
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
68
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
69
+ etc.)
70
+
71
+ This model is also a Flax Linen
72
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
73
+     regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
74
+
75
+ Parameters:
76
+ config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
77
+ Initializing with a config file does not load the weights associated with the model, only the
78
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
79
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
80
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
81
+ `jax.numpy.bfloat16` (on TPUs).
82
+
83
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
84
+ specified all the computation will be performed with the given `dtype`.
85
+
86
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
87
+ parameters.**
88
+
89
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
90
+ [`~FlaxPreTrainedModel.to_bf16`].
91
+ """
92
+
93
+ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
94
+ Args:
95
+ inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
96
+ Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
97
+ or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
98
+ library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
99
+ [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
100
+             *numpy.ndarray*.
101
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
102
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
103
+
104
+ - 1 for tokens that are **not masked**,
105
+ - 0 for tokens that are **masked**.
106
+
107
+ [What are attention masks?](../glossary#attention-mask)
108
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
109
+ Indices of decoder input sequence tokens in the vocabulary.
110
+
111
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
112
+ [`PreTrainedTokenizer.__call__`] for details.
113
+
114
+ [What are input IDs?](../glossary#input-ids)
115
+
116
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
117
+ `past_key_values`).
118
+
119
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
120
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
121
+ and prepending them with the `decoder_start_token_id`.
122
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
123
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
124
+ be used by default.
125
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
126
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
127
+ range `[0, config.decoder.max_position_embeddings - 1]`.
128
+ output_hidden_states (`bool`, *optional*):
129
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
130
+ more detail.
131
+ return_dict (`bool`, *optional*):
132
+ If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
133
+ """
134
+
135
+ SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
136
+ Args:
137
+ inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
138
+ Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac*
139
+ or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile
140
+ library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or
141
+ [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type
142
+             *numpy.ndarray*.
143
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
144
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
145
+
146
+ - 1 for tokens that are **not masked**,
147
+ - 0 for tokens that are **masked**.
148
+
149
+ [What are attention masks?](../glossary#attention-mask)
150
+ output_attentions (`bool`, *optional*):
151
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
152
+ tensors for more detail.
153
+ output_hidden_states (`bool`, *optional*):
154
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
155
+ more detail.
156
+ return_dict (`bool`, *optional*):
157
+ If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
158
+ """
159
+
160
+ SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
161
+ Args:
162
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
163
+ Indices of decoder input sequence tokens in the vocabulary.
164
+
165
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
166
+ [`PreTrainedTokenizer.__call__`] for details.
167
+
168
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
169
+
170
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
171
+ `past_key_values`).
172
+
173
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
174
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
175
+ and prepending them with the `decoder_start_token_id`.
176
+ encoder_outputs (`tuple(tuple(jnp.ndarray)`):
177
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
178
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
179
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
180
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
181
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
182
+
183
+ - 1 for tokens that are **not masked**,
184
+ - 0 for tokens that are **masked**.
185
+
186
+ [What are attention masks?](../glossary#attention-mask)
187
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
188
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
189
+ be used by default.
190
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
191
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
192
+ range `[0, config.decoder.max_position_embeddings - 1]`.
193
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
194
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
195
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
196
+ output_attentions (`bool`, *optional*):
197
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
198
+ tensors for more detail.
199
+ output_hidden_states (`bool`, *optional*):
200
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
201
+ more detail.
202
+ return_dict (`bool`, *optional*):
203
+ If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
204
+ plain tuple.
205
+ """
206
+
207
+ @flax.struct.dataclass
208
+ class FlaxBeamSearchOutput(ModelOutput):
209
+ """
210
+     Flax Base class for outputs of generation models using beam search.
211
+
212
+
213
+ Args:
214
+ sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
215
+ The generated sequences.
216
+             The scores (log probabilities) of the generated sequences.
217
+ The scores (log probabilites) of the generated sequences.
218
+ """
219
+
220
+ sequences: jnp.ndarray = None
221
+ scores: jnp.ndarray = None
222
+
223
+
224
+ @flax.struct.dataclass
225
+ class BeamSearchState:
226
+ cur_len: jnp.ndarray
227
+ running_sequences: jnp.ndarray
228
+ running_scores: jnp.ndarray
229
+ sequences: jnp.ndarray
230
+ scores: jnp.ndarray
231
+ is_sent_finished: jnp.ndarray
232
+ model_kwargs: Dict[str, jnp.ndarray]
233
+
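+     # Shape conventions assumed here (mirroring the flax beam-search utilities): `running_sequences` and
+     # `sequences` are (batch_size, num_beams, max_length) token arrays, `running_scores` and `scores` are
+     # (batch_size, num_beams) log-probabilities, `is_sent_finished` is a (batch_size, num_beams) boolean
+     # array, and `model_kwargs` carries the per-beam decoder cache between steps.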
234
+
235
+
236
+
237
+ class FlaxSpeechEncoderDecoderModule(nn.Module):
238
+ config: SpeechEncoderDecoderConfig
239
+ dtype: jnp.dtype = jnp.float32
240
+
241
+ def setup(self):
242
+ encoder_config = self.config.encoder
243
+ decoder_config = self.config.decoder
244
+
245
+ # TODO: configure FlaxAutoModel mappings (required when trialling different encoder-decoder combinations)
246
+ encoder_module = FlaxWav2Vec2Module
247
+ decoder_module = FlaxBartForCausalLMModule
248
+
249
+ self.encoder = encoder_module(encoder_config, dtype=self.dtype)
250
+ self.decoder = decoder_module(decoder_config, dtype=self.dtype)
251
+
252
+         # encoder outputs might need to be projected to a different dimension for the decoder
253
+ if (
254
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
255
+ and self.decoder.config.cross_attention_hidden_size is None
256
+ ):
257
+ self.enc_to_dec_proj = nn.Dense(
258
+ self.decoder.config.hidden_size,
259
+ kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
260
+ dtype=self.dtype,
261
+ )
262
+ else:
263
+ self.enc_to_dec_proj = None
264
+
265
+ def _get_feat_extract_output_lengths(
266
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
267
+ ):
268
+ """
269
+ Computes the output length of the convolutional layers
270
+ """
271
+
272
+ add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter
273
+
274
+ def _conv_out_length(input_length, kernel_size, stride):
275
+ # 1D convolutional layer output length formula taken
276
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
277
+ return (input_length - kernel_size) // stride + 1
278
+
279
+ for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
280
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
281
+
282
+ if add_adapter:
283
+ for _ in range(self.config.encoder.num_adapter_layers):
284
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)
285
+
286
+ return input_lengths
287
+
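+     # Worked example of the formula above, assuming the standard wav2vec2 feature extractor
+     # (conv_kernel=(10, 3, 3, 3, 3, 2, 2), conv_stride=(5, 2, 2, 2, 2, 2, 2), no adapter): one second of
+     # 16 kHz audio, i.e. 16000 samples, is reduced layer by layer to
+     # 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49 output frames, i.e. roughly 49 frames per second.
+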
288
+ def _get_encoder_module(self):
289
+ return self.encoder
290
+
291
+ def _get_projection_module(self):
292
+ return self.enc_to_dec_proj
293
+
294
+ def _get_decoder_module(self):
295
+ return self.decoder
296
+
297
+ def __call__(
298
+ self,
299
+ inputs,
300
+ attention_mask,
301
+ decoder_input_ids,
302
+ decoder_attention_mask,
303
+ decoder_position_ids,
304
+ encoder_outputs=None,
305
+ extract_features=None,
306
+ output_attentions: bool = False,
307
+ output_hidden_states: bool = False,
308
+ output_features: bool = False,
309
+ return_dict: bool = True,
310
+ deterministic: bool = True,
311
+ freeze_feature_encoder: bool = False,
312
+ ):
313
+ if encoder_outputs is None:
314
+ encoder_outputs = self.encoder(
315
+ inputs,
316
+ attention_mask=attention_mask,
317
+ extract_features=extract_features,
318
+ output_attentions=output_attentions,
319
+ output_hidden_states=output_hidden_states,
320
+ output_features=output_features,
321
+ return_dict=return_dict,
322
+ deterministic=deterministic,
323
+ freeze_feature_encoder=freeze_feature_encoder,
324
+ )
325
+
326
+ if output_features:
327
+ return encoder_outputs
328
+
329
+ encoder_hidden_states = encoder_outputs[0]
330
+
331
+ # optionally project encoder_hidden_states
332
+ if self.enc_to_dec_proj is not None:
333
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
334
+
335
+ # compute correct encoder attention mask
336
+ if attention_mask is not None:
337
+ encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
338
+ encoder_hidden_states.shape[1], attention_mask
339
+ )
340
+ else:
341
+ encoder_attention_mask = None
342
+
343
+         # run the decoder with cross-attention over the (projected) encoder hidden states
344
+ decoder_outputs = self.decoder(
345
+ input_ids=decoder_input_ids,
346
+ attention_mask=decoder_attention_mask,
347
+ position_ids=decoder_position_ids,
348
+ encoder_hidden_states=encoder_hidden_states,
349
+ encoder_attention_mask=encoder_attention_mask,
350
+ output_attentions=output_attentions,
351
+ output_hidden_states=output_hidden_states,
352
+ return_dict=return_dict,
353
+ deterministic=deterministic,
354
+ )
355
+
356
+ if not return_dict:
357
+ return decoder_outputs + encoder_outputs
358
+
359
+ return FlaxSeq2SeqLMOutput(
360
+ logits=decoder_outputs.logits,
361
+ decoder_hidden_states=decoder_outputs.hidden_states,
362
+ decoder_attentions=decoder_outputs.attentions,
363
+ cross_attentions=decoder_outputs.cross_attentions,
364
+ encoder_last_hidden_state=encoder_hidden_states,
365
+ encoder_hidden_states=encoder_outputs.hidden_states,
366
+ encoder_attentions=encoder_outputs.attentions,
367
+ )
368
+
369
+
370
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING)
371
+ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
372
+ r"""
373
+ [`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture
374
+ with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one
375
+ as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the
376
+ encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder.
377
+ """
378
+
379
+ config_class = SpeechEncoderDecoderConfig
380
+ base_model_prefix: str = "speech_encoder_decoder"
381
+ module_class = FlaxSpeechEncoderDecoderModule
382
+
383
+ def __init__(
384
+ self,
385
+ config: SpeechEncoderDecoderConfig,
386
+ input_shape: Optional[Tuple] = None,
387
+ seed: int = 0,
388
+ dtype: jnp.dtype = jnp.float32,
389
+ _do_init: bool = True,
390
+ **kwargs
391
+ ):
392
+
393
+ if not _do_init:
394
+ raise ValueError(
395
+ "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
396
+ )
397
+
398
+ if config.decoder.cross_attention_hidden_size is not None:
399
+ # Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
400
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
401
+ raise ValueError(
402
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
403
+ "it has to be equal to the encoder's `hidden_size`. "
404
+ f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
405
+ f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
406
+ )
407
+
408
+ # make sure input & output embeddings are not tied
409
+ config.tie_word_embeddings = False
410
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
411
+
412
+ if input_shape is None:
413
+ # speech encoders almost always downsample the sequence length dimension
414
+ encoder_input_length = 1024
415
+ decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
416
+ input_shape = ((1, encoder_input_length), (1, decoder_input_length))
417
+
418
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
419
+
420
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
421
+ encoder_input_shape, decoder_input_shape = input_shape
422
+
423
+ # init input DeviceArrays
424
+ inputs = jnp.zeros(encoder_input_shape, dtype="f4")
425
+ attention_mask = jnp.ones_like(inputs, dtype="i4")
426
+ decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
427
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
428
+
429
+ batch_size, sequence_length = inputs.shape
430
+
431
+ decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
432
+         if decoder_batch_size != batch_size:
433
+ raise ValueError(
434
+ f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
435
+ )
436
+ decoder_position_ids = jnp.broadcast_to(
437
+ jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
438
+ )
439
+
440
+ params_rng, dropout_rng = jax.random.split(rng)
441
+ rngs = {"params": params_rng, "dropout": dropout_rng}
442
+
443
+ return self.module.init(
444
+ rngs,
445
+ inputs,
446
+ attention_mask,
447
+ decoder_input_ids,
448
+ decoder_attention_mask,
449
+ decoder_position_ids,
450
+ )["params"]
451
+
452
+ def init_cache(self, batch_size, max_length, encoder_outputs):
453
+ r"""
454
+ Args:
455
+ batch_size (`int`):
456
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
457
+ max_length (`int`):
458
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
459
+ cache.
460
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
461
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
462
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
463
+ is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
464
+ cross-attention of the decoder.
465
+ """
466
+ # init input variables to retrieve cache
467
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
468
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
469
+ decoder_position_ids = jnp.broadcast_to(
470
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
471
+ )
472
+
473
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
474
+ decoder_module = module._get_decoder_module()
475
+ return decoder_module(
476
+ input_ids=decoder_input_ids,
477
+ attention_mask=decoder_attention_mask,
478
+ position_ids=decoder_position_ids,
479
+ **kwargs,
480
+ )
481
+
482
+ init_variables = self.module.init(
483
+ jax.random.PRNGKey(0),
484
+ decoder_input_ids=decoder_input_ids,
485
+ decoder_attention_mask=decoder_attention_mask,
486
+ decoder_position_ids=decoder_position_ids,
487
+ encoder_hidden_states=encoder_outputs[0],
488
+ init_cache=True,
489
+ method=_decoder_forward, # we only need to call the decoder to init the cache
490
+ )
491
+ return unfreeze(init_variables["cache"])
492
+
493
+ def _get_feat_extract_output_lengths(
494
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
495
+ ):
496
+ return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
497
+
498
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
499
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
500
+ def encode(
501
+ self,
502
+ inputs: jnp.ndarray,
503
+ attention_mask: Optional[jnp.ndarray] = None,
504
+ extract_features: Optional[jnp.ndarray] = None,
505
+ output_attentions: Optional[bool] = None,
506
+ output_hidden_states: Optional[bool] = None,
507
+ output_features: Optional[bool] = None,
508
+ return_dict: Optional[bool] = None,
509
+ train: bool = False,
510
+ freeze_feature_encoder: bool = False,
511
+ params: dict = None,
512
+ dropout_rng: PRNGKey = None,
513
+ ):
514
+ r"""
515
+ Returns:
516
+
517
+ Example:
518
+
519
+ ```python
520
+ >>> from transformers import FlaxSpeechEncoderDecoderModel
521
+
522
+ >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
523
+ >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
524
+ ... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
525
+ ... )
526
+
527
+ >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
528
+ >>> encoder_outputs = model.encode(inputs)
529
+ ```"""
530
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
531
+ output_hidden_states = (
532
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
533
+ )
534
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
535
+
536
+ if attention_mask is None:
537
+ attention_mask = jnp.ones_like(inputs, dtype="i4")
538
+
539
+ if extract_features is not None:
540
+ extract_features = jnp.array(extract_features, dtype="f4")
541
+
542
+ # Handle any PRNG if needed
543
+ rngs = {}
544
+ if dropout_rng is not None:
545
+ rngs["dropout"] = dropout_rng
546
+
547
+ def _encoder_forward(module, inputs, attention_mask, **kwargs):
548
+ encode_module = module._get_encoder_module()
549
+ return encode_module(inputs, attention_mask, **kwargs)
550
+
551
+ outputs = self.module.apply(
552
+ {"params": params or self.params},
553
+ inputs=jnp.array(inputs, dtype="f4"),
554
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
555
+ extract_features=extract_features,
556
+ output_attentions=output_attentions,
557
+ output_hidden_states=output_hidden_states,
558
+ output_features=output_features,
559
+ return_dict=return_dict,
560
+ deterministic=not train,
561
+ freeze_feature_encoder=freeze_feature_encoder,
562
+ rngs=rngs,
563
+ method=_encoder_forward,
564
+ )
565
+
566
+ if return_dict and not output_features:
567
+ outputs = FlaxBaseModelOutput(
568
+ last_hidden_state=outputs.last_hidden_state,
569
+ hidden_states=outputs.hidden_states,
570
+ attentions=outputs.attentions,
571
+ )
572
+
573
+ return outputs
574
+
575
+ @add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
576
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
577
+ def decode(
578
+ self,
579
+ decoder_input_ids,
580
+ encoder_outputs,
581
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
582
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
583
+ decoder_position_ids: Optional[jnp.ndarray] = None,
584
+ past_key_values: dict = None,
585
+ output_attentions: Optional[bool] = None,
586
+ output_hidden_states: Optional[bool] = None,
587
+ return_dict: Optional[bool] = None,
588
+ train: bool = False,
589
+ params: dict = None,
590
+ dropout_rng: PRNGKey = None,
591
+ ):
592
+ r"""
593
+ Returns:
594
+
595
+ Example:
596
+
597
+ ```python
598
+ >>> from transformers import FlaxSpeechEncoderDecoderModel
599
+ >>> import jax.numpy as jnp
600
+
601
+ >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
602
+ >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
603
+ ... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
604
+ ... )
605
+
606
+ >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
607
+ >>> encoder_outputs = model.encode(inputs)
608
+
609
+ >>> decoder_start_token_id = model.config.decoder.bos_token_id
610
+ >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id
611
+
612
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
613
+ >>> logits = outputs.logits
614
+ ```"""
615
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
616
+ output_hidden_states = (
617
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
618
+ )
619
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
620
+
621
+ encoder_hidden_states = encoder_outputs[0]
622
+ if encoder_attention_mask is None:
623
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
624
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
625
+
626
+ batch_size, sequence_length = decoder_input_ids.shape
627
+ if decoder_attention_mask is None:
628
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
629
+
630
+ if decoder_position_ids is None:
631
+ if past_key_values is not None:
632
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
633
+
634
+ decoder_position_ids = jnp.broadcast_to(
635
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
636
+ )
637
+
638
+ # Handle any PRNG if needed
639
+ rngs = {}
640
+ if dropout_rng is not None:
641
+ rngs["dropout"] = dropout_rng
642
+
643
+ params = {"params": params or self.params}
644
+
645
+         # If `past_key_values` are passed, the cache is already initialized, so a private flag `init_cache` has to be
646
+         # passed down to ensure the cache is used. The cache also has to be marked as mutable so that
647
+         # it can be changed by the FlaxBartAttention module.
648
+ if past_key_values:
649
+ params["cache"] = past_key_values
650
+ mutable = ["cache"]
651
+ else:
652
+ mutable = False
653
+
654
+ def _decoder_forward(
655
+ module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
656
+ ):
657
+
658
+ projection_module = module._get_projection_module()
659
+ decoder_module = module._get_decoder_module()
660
+
661
+ # optionally project encoder_hidden_states
662
+ if projection_module is not None:
663
+ encoder_hidden_states = projection_module(encoder_hidden_states)
664
+
665
+ return decoder_module(
666
+ decoder_input_ids,
667
+ decoder_attention_mask,
668
+ decoder_position_ids,
669
+ encoder_hidden_states,
670
+ **kwargs,
671
+ )
672
+
673
+ outputs = self.module.apply(
674
+ params,
675
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
676
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
677
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
678
+ encoder_hidden_states=encoder_hidden_states,
679
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
680
+ output_attentions=output_attentions,
681
+ output_hidden_states=output_hidden_states,
682
+ return_dict=return_dict,
683
+ deterministic=not train,
684
+ rngs=rngs,
685
+ mutable=mutable,
686
+ method=_decoder_forward,
687
+ )
688
+
689
+ # add updated cache to model output
690
+ if past_key_values is not None and return_dict:
691
+ outputs, past = outputs
692
+ outputs["past_key_values"] = unfreeze(past["cache"])
693
+ return outputs
694
+ elif past_key_values is not None and not return_dict:
695
+ outputs, past = outputs
696
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
697
+
698
+ return outputs
699
+
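+     # A sketch of cached, step-by-step greedy decoding with encode/init_cache/decode (illustrative only;
+     # `generate` automates this and additionally handles masks, EOS and beam search):
+     #
+     #     encoder_outputs = model.encode(inputs, attention_mask=attention_mask)
+     #     past = model.init_cache(batch_size, max_length, encoder_outputs)
+     #     decoder_input_ids = jnp.full((batch_size, 1), model.config.decoder_start_token_id, dtype="i4")
+     #     for step in range(max_length - 1):
+     #         outputs = model.decode(
+     #             decoder_input_ids[:, -1:],
+     #             encoder_outputs,
+     #             decoder_position_ids=jnp.full((batch_size, 1), step, dtype="i4"),
+     #             past_key_values=past,
+     #         )
+     #         past = outputs["past_key_values"]
+     #         next_token = jnp.argmax(outputs.logits[:, -1:], axis=-1)
+     #         decoder_input_ids = jnp.concatenate([decoder_input_ids, next_token], axis=-1)
+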
700
+ @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING)
701
+ @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
702
+ def __call__(
703
+ self,
704
+ inputs: jnp.ndarray,
705
+ attention_mask: Optional[jnp.ndarray] = None,
706
+ extract_features: Optional[jnp.ndarray] = None,
707
+ decoder_input_ids: Optional[jnp.ndarray] = None,
708
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
709
+ decoder_position_ids: Optional[jnp.ndarray] = None,
710
+ output_attentions: Optional[bool] = None,
711
+ output_hidden_states: Optional[bool] = None,
712
+ output_features: Optional[bool] = None,
713
+ return_dict: Optional[bool] = None,
714
+ train: bool = False,
715
+ freeze_feature_encoder: bool = False,
716
+ params: dict = None,
717
+ dropout_rng: PRNGKey = None,
718
+ ):
719
+ r"""
720
+ Returns:
721
+
722
+ Examples:
723
+
724
+ ```python
725
+ >>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer
726
+
727
+ >>> # load a fine-tuned wav2vec2-2-bart model
728
+ >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
729
+ >>> # load output tokenizer
730
+ >>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large")
731
+
732
+ >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
733
+
734
+ >>> # use bart's special bos, pad and eos tokens
735
+ >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id
736
+ >>> model.config.pad_token_id = model.decoder.config.pad_token_id
737
+ >>> model.config.eos_token_id = model.decoder.config.eos_token_id
738
+
739
+ >>> outputs = model.generate(inputs)
740
+ # Assert something? More interesting input? dtype correct?
741
+ ```
742
+ """
743
+
744
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
745
+ output_hidden_states = (
746
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
747
+ )
748
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
749
+
750
+ # prepare encoder inputs
751
+ if attention_mask is None:
752
+ attention_mask = jnp.ones_like(inputs, dtype="i4")
753
+
754
+ if extract_features is not None:
755
+ inputs = None # we can omit passing the inputs to the model to save memory
756
+ extract_features = jnp.array(extract_features, dtype="f4")
757
+ else:
758
+ inputs = jnp.array(inputs, dtype="f4")
759
+
760
+ # prepare decoder inputs
761
+ if decoder_input_ids is None:
762
+ raise ValueError(
763
+ "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
764
+ )
765
+ if decoder_attention_mask is None:
766
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
767
+ if decoder_position_ids is None:
768
+ batch_size, sequence_length = decoder_input_ids.shape
769
+ decoder_position_ids = jnp.broadcast_to(
770
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
771
+ )
772
+
773
+ # Handle any PRNG if needed
774
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
775
+
776
+ return self.module.apply(
777
+ {"params": params or self.params},
778
+ inputs=inputs,
779
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
780
+ extract_features=extract_features,
781
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
782
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
783
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
784
+ output_attentions=output_attentions,
785
+ output_hidden_states=output_hidden_states,
786
+ output_features=output_features,
787
+ return_dict=return_dict,
788
+ deterministic=not train,
789
+ freeze_feature_encoder=freeze_feature_encoder,
790
+ rngs=rngs,
791
+ )
792
+
793
+ def prepare_inputs_for_generation(
794
+ self,
795
+ decoder_input_ids,
796
+ max_length,
797
+ attention_mask: Optional[jnp.DeviceArray] = None,
798
+ decoder_attention_mask: Optional[jnp.DeviceArray] = None,
799
+ encoder_outputs=None,
800
+ **kwargs
801
+ ):
802
+ # initializing the cache
803
+ batch_size, seq_length = decoder_input_ids.shape
804
+
805
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
806
+ # Note that usually one would have to put 0's in the attention_mask for x > input.shape[-1] and x < cache_length.
807
+ # But since the decoder uses a causal mask, those positions are masked anyways.
808
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
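+ # For example (illustrative values only): with max_length = 5 and decoder_attention_mask = [[1, 1, 0]],
+ # the dynamic_update_slice below yields extended_attention_mask = [[1, 1, 0, 1, 1]] and
+ # decoder_position_ids = cumsum(...) - 1 = [[0, 1, 1]]; the trailing ones are harmless because the
+ # causal mask already hides positions beyond the current length.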
809
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
810
+ if decoder_attention_mask is not None:
811
+ decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
812
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
813
+ else:
814
+ decoder_position_ids = jnp.broadcast_to(
815
+ jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
816
+ )
817
+
818
+ return {
819
+ "past_key_values": past_key_values,
820
+ "encoder_outputs": encoder_outputs,
821
+ "encoder_attention_mask": attention_mask,
822
+ "decoder_attention_mask": extended_attention_mask,
823
+ "decoder_position_ids": decoder_position_ids,
824
+ }
825
+
826
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
827
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
828
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
829
+ return model_kwargs
830
+
831
+ @classmethod
832
+ def from_encoder_decoder_pretrained(
833
+ cls,
834
+ encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
835
+ decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
836
+ *model_args,
837
+ **kwargs
838
+ ) -> FlaxPreTrainedModel:
839
+ r"""
840
+ Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
841
+ checkpoints.
842
+
843
+ Params:
844
+ encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
845
+ Information necessary to initiate the encoder. Can be either:
846
+
847
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
848
+ Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
849
+ user or organization name, like `dbmdz/bert-base-german-cased`.
850
+ - A path to a *directory* containing model weights saved using
851
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
852
+
853
+ decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
854
+ Information necessary to initiate the decoder. Can be either:
855
+
856
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
857
+ Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
858
+ user or organization name, like `dbmdz/bert-base-german-cased`.
859
+ - A path to a *directory* containing model weights saved using
860
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
861
+
862
+ model_args (remaining positional arguments, *optional*):
863
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
864
+
865
+ kwargs (remaining dictionary of keyword arguments, *optional*):
866
+ Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
867
+ `output_attentions=True`).
868
+
869
+ - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
870
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
871
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
872
+
873
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
874
+
875
+ Example:
876
+
877
+ ```python
878
+ >>> from transformers import FlaxSpeechEncoderDecoderModel
879
+
880
+ >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
881
+ >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
882
+ ... "facebook/wav2vec2-large-lv60", "facebook/bart-large"
883
+ ... )
884
+ >>> # saving model after fine-tuning
885
+ >>> model.save_pretrained("./wav2vec2-2-bart-large")
886
+ >>> # load fine-tuned model
887
+ >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
888
+ ```"""
889
+
890
+ kwargs_encoder = {
891
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
892
+ }
893
+
894
+ kwargs_decoder = {
895
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
896
+ }
897
+
898
+ # remove encoder, decoder kwargs from kwargs
899
+ for key in kwargs_encoder.keys():
900
+ del kwargs["encoder_" + key]
901
+ for key in kwargs_decoder.keys():
902
+ del kwargs["decoder_" + key]
903
+
904
+ # Load and initialize the encoder and decoder
905
+ # The distinction between encoder and decoder at the model level is made
906
+ # by the value of the flag `is_decoder` that we need to set correctly.
907
+ encoder = kwargs_encoder.pop("model", None)
908
+ if encoder is None:
909
+ if encoder_pretrained_model_name_or_path is None:
910
+ raise ValueError(
911
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
912
+ "to be defined."
913
+ )
914
+
915
+ if "config" not in kwargs_encoder:
916
+ # TODO: AutoConfig .from_pretrained
917
+ encoder_config, kwargs_encoder = Wav2Vec2Config.from_pretrained(
918
+ encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
919
+ )
920
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
921
+ logger.info(
922
+ f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
923
+ "from a decoder model. Cross-attention and casual mask are disabled."
924
+ )
925
+ encoder_config.is_decoder = False
926
+ encoder_config.add_cross_attention = False
927
+
928
+ kwargs_encoder["config"] = encoder_config
929
+
930
+ # TODO: FlaxAutoModel .from_pretrained
931
+ encoder = FlaxWav2Vec2Model.from_pretrained(
932
+ encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
933
+ )
934
+
935
+ decoder = kwargs_decoder.pop("model", None)
936
+ if decoder is None:
937
+ if decoder_pretrained_model_name_or_path is None:
938
+ raise ValueError(
939
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
940
+ "to be defined."
941
+ )
942
+
943
+ if "config" not in kwargs_decoder:
944
+ # TODO: AutoConfig .from_pretrained
945
+ decoder_config, kwargs_decoder = BartConfig.from_pretrained(
946
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
947
+ )
948
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
949
+ logger.info(
950
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
951
+ f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
952
+ f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
953
+ "cross attention layers."
954
+ )
955
+ decoder_config.is_decoder = True
956
+ decoder_config.add_cross_attention = True
957
+
958
+ kwargs_decoder["config"] = decoder_config
959
+
960
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
961
+ logger.warning(
962
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
963
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
964
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
965
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
966
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
967
+ )
968
+
969
+ # TODO: FlaxAutoModelForCausalLM .from_pretrained
970
+ decoder = FlaxBartForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
971
+
972
+ # instantiate config with corresponding kwargs
973
+ dtype = kwargs.pop("dtype", jnp.float32)
974
+ config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
975
+
976
+ # make sure input & output word embeddings are not tied
977
+ config.tie_word_embeddings = False
978
+
979
+ # init model
980
+ model = cls(config, dtype=dtype)
981
+ model.params["encoder"] = encoder.params
982
+ model.params["decoder"] = decoder.params
983
+
984
+ return model
985
+
986
+ def _beam_search(
987
+ self,
988
+ input_ids: None,
989
+ max_length: Optional[int] = None,
990
+ pad_token_id: Optional[int] = None,
991
+ eos_token_id: Optional[int] = None,
992
+ length_penalty: Optional[float] = None,
993
+ early_stopping: Optional[bool] = None,
994
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
995
+ trace: bool = True,
996
+ params: Optional[Dict[str, jnp.ndarray]] = None,
997
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
998
+ ):
999
+ """
1000
+ This beam search function is heavily inspired by Flax's official example:
1001
+ https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
1002
+ """
1003
+
1004
+ def flatten_beam_dim(tensor):
1005
+ """Flattens the first two dimensions of a non-scalar array."""
1006
+ # ignore scalars (e.g. cache index)
1007
+ if tensor.ndim == 0 or tensor.ndim == 1:
1008
+ return tensor
1009
+ elif tensor.ndim == 6:
1010
+ return tensor.reshape(tensor.shape[:1] + (tensor.shape[1] * tensor.shape[2],) + tensor.shape[3:])
1011
+ return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
1012
+
1013
+ def unflatten_beam_dim(tensor, batch_size, num_beams):
1014
+ """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
1015
+ # ignore scalars (e.g. cache index)
1016
+ if tensor.ndim == 0 or tensor.ndim == 1:
1017
+ return tensor
1018
+ if tensor.ndim == 5:
1019
+ return tensor.reshape(tensor.shape[:1] + (batch_size, num_beams) + tensor.shape[2:])
1020
+ return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
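+
+ # e.g. (illustrative shapes): flatten_beam_dim maps (batch, beams, length) = (2, 4, 16) to (8, 16),
+ # and unflatten_beam_dim(..., batch_size=2, num_beams=4) maps (8, 16) back to (2, 4, 16); the
+ # ndim == 5/6 branches appear to handle the extra leading (layer) axis of attention caches stacked by nn.scan.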
1021
+
1022
+ def gather_beams(nested, beam_indices, batch_size, new_num_beams):
1023
+ """
1024
+ Gathers the beam slices indexed by beam_indices into new beam array.
1025
+ """
1026
+ batch_indices = jnp.reshape(
1027
+ jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
1028
+ )
1029
+
1030
+ def gather_fn(tensor):
1031
+ # ignore scalars (e.g. cache index)
1032
+ if tensor.ndim == 0 or tensor.ndim == 1:
1033
+ return tensor
1034
+ if tensor.ndim == 6:
1035
+ return tensor[:, batch_indices, beam_indices]
1036
+ return tensor[batch_indices, beam_indices]
1037
+
1038
+ return jax.tree_map(gather_fn, nested)
1039
+
1040
+ # init values
1041
+ max_length = max_length if max_length is not None else self.config.max_length
1042
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
1043
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
1044
+ length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
1045
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
1046
+
1047
+ batch_size, num_beams, cur_len = input_ids.shape
1048
+
1049
+ eos_token_id = jnp.array(eos_token_id)
1050
+ pad_token_id = jnp.array(pad_token_id)
1051
+ cur_len = jnp.array(cur_len)
1052
+
1053
+ # per batch,beam-item holding current token in loop.
1054
+ sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
1055
+ running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
1056
+ running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
1057
+
1058
+ # per batch,beam-item state bit indicating if sentence has finished.
1059
+ is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)
1060
+
1061
+ # per batch,beam-item score, logprobs
1062
+ running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
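+ # only the first beam starts with score 0; the remaining beams start at -1.0e7 so that the first
+ # expansion step does not select num_beams identical copies of the same prompt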
1063
+ scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)
1064
+
1065
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
1066
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
1067
+ model = self.decode if self.config.is_encoder_decoder else self
1068
+
1069
+ # flatten beam dim
1070
+ if "encoder_outputs" in model_kwargs:
1071
+ model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
1072
+ model_kwargs["encoder_outputs"]["last_hidden_state"]
1073
+ )
1074
+ if "attention_mask" in model_kwargs:
1075
+ model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])
1076
+
1077
+ # initialize model specific kwargs
1078
+ model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)
1079
+
1080
+ # initialize state
1081
+ state = BeamSearchState(
1082
+ cur_len=cur_len,
1083
+ running_sequences=running_sequences,
1084
+ running_scores=running_scores,
1085
+ sequences=sequences,
1086
+ scores=scores,
1087
+ is_sent_finished=is_sent_finished,
1088
+ model_kwargs=model_kwargs,
1089
+ )
1090
+
1091
+ def beam_search_cond_fn(state):
1092
+ """beam search state termination condition fn."""
1093
+
1094
+ # 1. is less than max length?
1095
+ not_max_length_yet = state.cur_len < max_length
1096
+
1097
+ # 2. can the new beams still improve?
1098
+ best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
1099
+ worst_finished_score = jnp.where(
1100
+ state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
1101
+ )
1102
+ improvement_still_possible = jnp.all(worst_finished_score < best_running_score)
1103
+
1104
+ # 3. is there still a beam that has not finished?
1105
+ still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
1106
+
1107
+ return not_max_length_yet & still_open_beam & improvement_still_possible
1108
+
1109
+ def beam_search_body_fn(state, input_ids_length=1):
1110
+ """beam search state update fn."""
1111
+ # 1. Forward current tokens
1112
+ # Collect the current position slice along length to feed the fast
1113
+ # autoregressive decoder model. Flatten the beam dimension into batch
1114
+ # dimension for feeding into the model.
1115
+ # unflatten beam dimension
1116
+ # Unflatten beam dimension in attention cache arrays
1117
+ input_token = flatten_beam_dim(
1118
+ lax.dynamic_slice(
1119
+ state.running_sequences,
1120
+ (0, 0, state.cur_len - input_ids_length),
1121
+ (batch_size, num_beams, input_ids_length),
1122
+ )
1123
+ )
1124
+ model_outputs = model(input_token, params=params, **state.model_kwargs)
1125
+
1126
+ logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
1127
+ cache = jax.tree_map(
1128
+ lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
1129
+ )
1130
+
1131
+ # adapt logits for FlaxMarianMTModel
1132
+ logits = self._adapt_logits_for_beam_search(logits)
1133
+
1134
+ # 2. Compute log probs
1135
+ # get log probabilities from logits,
1136
+ # process logits with processors (*e.g.* min_length, ...), and
1137
+ # add new logprobs to existing running logprobs scores.
1138
+ log_probs = jax.nn.log_softmax(logits)
1139
+ log_probs = logits_processor(
1140
+ flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len
1141
+ )
1142
+ log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
1143
+ log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
1144
+ vocab_size = log_probs.shape[2]
1145
+ log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))
1146
+
1147
+ # 3. Retrieve top-K
1148
+ # Each item in batch has num_beams * vocab_size candidate sequences.
1149
+ # For each item, get the top 2*k candidates with the highest log-
1150
+ # probabilities. We gather the top 2*K beams here so that even if the best
1151
+ # K sequences reach EOS simultaneously, we have another K sequences
1152
+ # remaining to continue the live beam search.
1153
+ # Gather the top 2*K scores from _all_ beams.
1154
+ # Gather 2*k top beams.
1155
+ # Recover the beam index by floor division.
1156
+ # Recover token id by modulo division and expand Id array for broadcasting.
1157
+ # Update sequences for the 2*K top-k new sequences.
1158
+ beams_to_keep = 2 * num_beams
1159
+ topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
1160
+ topk_beam_indices = topk_indices // vocab_size
1161
+ topk_running_sequences = gather_beams(
1162
+ state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
1163
+ )
1164
+ topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
1165
+ topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))
1166
+
1167
+ # 4. Check which sequences have ended
1168
+ # Update current sequences:
1169
+ # Did any of these sequences reach an end marker?
1170
+ # To prevent these just finished sequences from being added to the current sequences
1171
+ # set of active beam search sequences, set their log probs to a very large
1172
+ # negative value.
1173
+ did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
1174
+ running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
1175
+ # 5. Get running sequences scores for next
1176
+ # Determine the top k beam indices (from top 2*k beams) from log probs
1177
+ # and gather top k beams (from top 2*k beams).
1178
+ next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
1179
+ next_running_sequences, next_running_scores = gather_beams(
1180
+ [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
1181
+ )
1182
+
1183
+ # 6. Process topk logits
1184
+ # Further process log probs:
1185
+ # - add length penalty
1186
+ # - make sure no scores can be added anymore if beam is full
1187
+ # - make sure still running sequences cannot be chosen as finalized beam
1188
+ topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
1189
+ beams_in_batch_are_full = (
1190
+ jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
1191
+ & early_stopping
1192
+ )
1193
+ add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
1194
+ topk_log_probs += add_penalty * np.array(-1.0e7)
1195
+
1196
+ # 7. Get scores, sequences, is sentence finished for next.
1197
+ # Combine sequences, scores, and flags along the beam dimension and compare
1198
+ # new finished sequence scores to existing finished scores and select the
1199
+ # best from the new set of beams
1200
+ merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
1201
+ merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
1202
+ merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
1203
+ topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
1204
+ next_sequences, next_scores, next_is_sent_finished = gather_beams(
1205
+ [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
1206
+ )
1207
+
1208
+ # 8. Update model kwargs.
1209
+ # Determine the top k beam indices from the original set of all beams.
1210
+ # With these, gather the top k beam-associated caches.
1211
+ next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
1212
+ next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
1213
+ model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
1214
+ next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
1215
+
1216
+ return BeamSearchState(
1217
+ cur_len=state.cur_len + 1,
1218
+ running_scores=next_running_scores,
1219
+ running_sequences=next_running_sequences,
1220
+ scores=next_scores,
1221
+ sequences=next_sequences,
1222
+ is_sent_finished=next_is_sent_finished,
1223
+ model_kwargs=next_model_kwargs,
1224
+ )
1225
+
1226
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
1227
+ if input_ids.shape[-1] > 1:
1228
+ state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)
1229
+
1230
+ if not trace:
1231
+ state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
1232
+ else:
1233
+ state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)
1234
+
1235
+ # Account for the edge-case where there are no finished sequences for a
1236
+ # particular batch item. If so, return running sequences for that batch item.
1237
+ none_finished = jnp.any(state.is_sent_finished, axis=1)
1238
+ sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
1239
+ scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
1240
+
1241
+ # return all beams for each batch and the best score
1242
+ sequences = sequences[:, :]
1243
+ scores = scores[:, -1]
1244
+
1245
+ return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
models/modeling_flax_wav2vec2.py ADDED
@@ -0,0 +1,975 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax Wav2Vec2 model."""
16
+
17
+ from functools import partial
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import flax
21
+ import flax.linen as nn
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from flax.core.frozen_dict import FrozenDict
25
+ from flax.linen import partitioning as nn_partitioning
26
+ from flax.linen.attention import dot_product_attention_weights
27
+ from jax import lax
28
+
29
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
30
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
31
+ from transformers.utils import ModelOutput
32
+
33
+ from models import Wav2Vec2Config
34
+
35
+ scan_with_axes = nn_partitioning.scan_with_axes
36
+ remat = nn_partitioning.remat
37
+
38
+
39
+ @flax.struct.dataclass
40
+ class FlaxWav2Vec2BaseModelOutput(ModelOutput):
41
+ """
42
+ Output type of [`FlaxWav2Vec2BaseModelOutput`], with potential hidden states and attentions.
43
+
44
+ Args:
45
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
46
+ Sequence of hidden-states at the output of the last layer of the model.
47
+ extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`):
48
+ Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim`
49
+ being the dimension of the last convolutional layer.
50
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
51
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
52
+ `(batch_size, sequence_length, hidden_size)`.
53
+
54
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
55
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
56
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
57
+ sequence_length)`.
58
+
59
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
60
+ heads.
61
+ """
62
+
63
+ last_hidden_state: jnp.ndarray = None
64
+ extract_features: jnp.ndarray = None
65
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
66
+ attentions: Optional[Tuple[jnp.ndarray]] = None
67
+
68
+
69
+ WAV_2_VEC_2_START_DOCSTRING = r"""
70
+ Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
71
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
72
+ Auli.
73
+
74
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
75
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
76
+ etc.)
77
+
78
+ This model is also a Flax Linen
79
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
80
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
81
+
82
+ Finally, this model supports inherent JAX features such as:
83
+
84
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
85
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
86
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
87
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
88
+
89
+ Parameters:
90
+ config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
91
+ Initializing with a config file does not load the weights associated with the model, only the
92
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
93
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
94
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
95
+ `jax.numpy.bfloat16` (on TPUs).
96
+
97
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
98
+ specified all the computation will be performed with the given `dtype`.
99
+
100
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
101
+ parameters.**
102
+
103
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
104
+ [`~FlaxPreTrainedModel.to_bf16`].
105
+ """
106
+
107
+
108
+ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
109
+ Args:
110
+ input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
111
+ Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
112
+ into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
113
+ soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
114
+ and conversion into a tensor of type *jnp.ndarray*. See [`Wav2Vec2Processor.__call__`] for details.
115
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
116
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
117
+ 1]`:
118
+
119
+ - 1 for tokens that are **not masked**,
120
+ - 0 for tokens that are **masked**.
121
+
122
+ [What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed
123
+ if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor
124
+ has `config.return_attention_mask == False`, such as
125
+ [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
126
+ passed to avoid degraded performance when doing batched inference. For such models `input_values` should
127
+ simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
128
+ different results depending on whether `input_values` is padded or not.
129
+ mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
130
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
131
+ masked extracted features in *config.proj_codevector_dim* space.
132
+ output_attentions (`bool`, *optional*):
133
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
134
+ tensors for more detail.
135
+ output_hidden_states (`bool`, *optional*):
136
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
137
+ more detail.
138
+ return_dict (`bool`, *optional*):
139
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
140
+ """
141
+
142
+
143
+ class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
144
+ config: Wav2Vec2Config
145
+ layer_id: int = 0
146
+ dtype: jnp.dtype = jnp.float32
147
+
148
+ def setup(self):
149
+ self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
150
+ self.out_conv_dim = self.config.conv_dim[self.layer_id]
151
+
152
+ self.conv = nn.Conv(
153
+ features=self.config.conv_dim[self.layer_id],
154
+ kernel_size=(self.config.conv_kernel[self.layer_id],),
155
+ strides=(self.config.conv_stride[self.layer_id],),
156
+ use_bias=self.config.conv_bias,
157
+ kernel_init=jax.nn.initializers.he_normal(),
158
+ padding="VALID",
159
+ dtype=self.dtype,
160
+ )
161
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
162
+ self.activation = ACT2FN[self.config.feat_extract_activation]
163
+
164
+ def __call__(self, hidden_states):
165
+ hidden_states = self.conv(hidden_states)
166
+ hidden_states = self.layer_norm(hidden_states)
167
+ hidden_states = self.activation(hidden_states)
168
+ return hidden_states
169
+
170
+
171
+ class FlaxConvWithWeightNorm(nn.Module):
172
+ config: Wav2Vec2Config
173
+ dtype: jnp.dtype = jnp.float32
174
+
175
+ def setup(self):
176
+ self.conv = nn.Conv(
177
+ features=self.config.hidden_size,
178
+ kernel_size=(self.config.num_conv_pos_embeddings,),
179
+ kernel_init=jax.nn.initializers.he_normal(),
180
+ padding="VALID",
181
+ feature_group_count=self.config.num_conv_pos_embedding_groups,
182
+ dtype=self.dtype,
183
+ )
184
+ weight_shape = (
185
+ self.conv.features,
186
+ self.conv.features // self.conv.feature_group_count,
187
+ self.conv.kernel_size[0],
188
+ )
189
+ self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
190
+ self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
191
+ self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
192
+ self.prev_padding = self.conv.kernel_size[0] // 2
193
+
194
+ def _get_normed_weights(self):
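+ # weight normalisation (Salimans & Kingma, 2016): kernel = g * v / ||v||, where the norm of v is
+ # taken over the spatial and input-channel axes (0, 1) and g is the learned per-output-channel gain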
195
+ weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
196
+ normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
197
+ normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
198
+ return normed_kernel
199
+
200
+ def __call__(self, hidden_states):
201
+ kernel = self._get_normed_weights()
202
+ hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0)))
203
+ hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states)
204
+ return hidden_states
205
+
206
+
207
+ class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
208
+ config: Wav2Vec2Config
209
+ dtype: jnp.dtype = jnp.float32
210
+
211
+ def setup(self):
212
+ self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
213
+ self.activation = ACT2FN[self.config.feat_extract_activation]
214
+ self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0
215
+
216
+ def __call__(self, hidden_states):
217
+ hidden_states = hidden_states.transpose((0, 1, 2))
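+ # note: transpose((0, 1, 2)) is the identity permutation (a no-op) -- Flax convolutions already take
+ # (batch, time, channels) input; presumably this mirrors the (B, C, T) <-> (B, T, C) transposes of the
+ # PyTorch implementation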
218
+
219
+ hidden_states = self.conv(hidden_states)
220
+
221
+ if self.num_pad_remove > 0:
222
+ hidden_states = hidden_states[:, : -self.num_pad_remove, :]
223
+ hidden_states = self.activation(hidden_states)
224
+
225
+ hidden_states = hidden_states.transpose((0, 1, 2))
226
+ return hidden_states
227
+
228
+
229
+ class FlaxConvLayersCollection(nn.Module):
230
+ config: Wav2Vec2Config
231
+ dtype: jnp.dtype = jnp.float32
232
+
233
+ def setup(self):
234
+ if self.config.feat_extract_norm == "layer":
235
+ # note that we can't use scan on the conv layers as they differ on a layer-by-layer basis
236
+ BlockLayer = remat(FlaxWav2Vec2LayerNormConvLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2LayerNormConvLayer
237
+ self.layers = [
238
+ BlockLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
239
+ for i in range(self.config.num_feat_extract_layers)
240
+ ]
241
+ elif self.config.feat_extract_norm == "group":
242
+ raise NotImplementedError("At the moment only ``config.feat_extact_norm == 'layer'`` is supported")
243
+ else:
244
+ raise ValueError(
245
+ f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
246
+ )
247
+
248
+ def __call__(self, hidden_states):
249
+ for i, conv_layer in enumerate(self.layers):
250
+ hidden_states = conv_layer(hidden_states)
251
+ return hidden_states
252
+
253
+
254
+ class FlaxWav2Vec2FeatureEncoder(nn.Module):
255
+ """Construct the features from raw audio waveform"""
256
+
257
+ config: Wav2Vec2Config
258
+ dtype: jnp.dtype = jnp.float32
259
+
260
+ def setup(self):
261
+ self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)
262
+
263
+ def __call__(self, input_values, freeze_feature_encoder=False):
264
+ hidden_states = input_values[:, :, None]
265
+ hidden_states = self.conv_layers(hidden_states)
266
+ if freeze_feature_encoder:
267
+ hidden_states = jax.lax.stop_gradient(hidden_states)
268
+ return hidden_states
269
+
270
+
271
+ class FlaxWav2Vec2FeatureProjection(nn.Module):
272
+ config: Wav2Vec2Config
273
+ dtype: jnp.dtype = jnp.float32
274
+
275
+ def setup(self):
276
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
277
+ self.projection = nn.Dense(
278
+ self.config.hidden_size,
279
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
280
+ dtype=self.dtype,
281
+ )
282
+ self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)
283
+
284
+ def __call__(self, hidden_states, deterministic=True):
285
+ norm_hidden_states = self.layer_norm(hidden_states)
286
+ hidden_states = self.projection(norm_hidden_states)
287
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
288
+ return hidden_states, norm_hidden_states
289
+
290
+
291
+ class FlaxWav2Vec2Attention(nn.Module):
292
+ config: Wav2Vec2Config
293
+ embed_dim: int
294
+ num_heads: int
295
+ dropout: float = 0.0
296
+ bias: bool = True
297
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
298
+
299
+ def setup(self) -> None:
300
+ self.head_dim = self.embed_dim // self.num_heads
301
+ if self.head_dim * self.num_heads != self.embed_dim:
302
+ raise ValueError(
303
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
304
+ )
305
+
306
+ dense = partial(
307
+ nn.Dense,
308
+ self.embed_dim,
309
+ use_bias=self.bias,
310
+ dtype=self.dtype,
311
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
312
+ )
313
+
314
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
315
+
316
+ self.fused_proj = nn.Dense(
317
+ self.embed_dim * 3,
318
+ use_bias=self.bias,
319
+ dtype=self.dtype,
320
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
321
+ )
322
+
323
+ self.out_proj = dense()
324
+
325
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
326
+
327
+ def _split_heads(self, hidden_states):
328
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
329
+
330
+ def _merge_heads(self, hidden_states):
331
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
332
+
333
+ def __call__(
334
+ self,
335
+ hidden_states: jnp.ndarray,
336
+ key_value_states: Optional[jnp.ndarray] = None,
337
+ attention_mask: Optional[jnp.ndarray] = None,
338
+ deterministic: bool = True,
339
+ ) -> Tuple[jnp.ndarray]:
340
+ """Input shape: Batch x Time x Channel"""
341
+
342
+ if self.config.fuse_matmuls:
343
+ attention_states = self.fused_proj(hidden_states)
344
+ query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)
345
+
346
+ else:
347
+ # get query proj
348
+ query_states = self.q_proj(hidden_states)
349
+
350
+ key_states = self.k_proj(hidden_states)
351
+ value_states = self.v_proj(hidden_states)
352
+
353
+ query_states = self._split_heads(query_states)
354
+ key_states = self._split_heads(key_states)
355
+ value_states = self._split_heads(value_states)
356
+
357
+ if attention_mask is not None:
358
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
359
+
360
+ # Convert the boolean attention mask to an attention bias.
361
+ if attention_mask is not None:
362
+ # attention mask in the form of attention bias
363
+ attention_bias = lax.select(
364
+ attention_mask > 0,
365
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
366
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
367
+ )
368
+ else:
369
+ attention_bias = None
370
+
371
+ dropout_rng = None
372
+ if not deterministic and self.dropout > 0.0:
373
+ dropout_rng = self.make_rng("dropout")
374
+
375
+ attn_weights = dot_product_attention_weights(
376
+ query_states,
377
+ key_states,
378
+ bias=attention_bias,
379
+ dropout_rng=dropout_rng,
380
+ dropout_rate=self.dropout,
381
+ broadcast_dropout=True,
382
+ deterministic=deterministic,
383
+ dtype=self.dtype,
384
+ precision=None,
385
+ )
386
+
387
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
388
+ attn_output = self._merge_heads(attn_output)
389
+ attn_output = self.out_proj(attn_output)
390
+
391
+ return attn_output, attn_weights
392
+
393
+
394
+ class FlaxWav2Vec2FeedForward(nn.Module):
395
+ config: Wav2Vec2Config
396
+ dtype: jnp.dtype = jnp.float32
397
+
398
+ def setup(self):
399
+ self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)
400
+
401
+ self.intermediate_dense = nn.Dense(
402
+ self.config.intermediate_size,
403
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
404
+ dtype=self.dtype,
405
+ )
406
+ if isinstance(self.config.hidden_act, str):
407
+ self.intermediate_act_fn = ACT2FN[self.config.hidden_act]
408
+ else:
409
+ self.intermediate_act_fn = self.config.hidden_act
410
+
411
+ self.output_dense = nn.Dense(
412
+ self.config.hidden_size,
413
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
414
+ dtype=self.dtype,
415
+ )
416
+ self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)
417
+
418
+ def __call__(self, hidden_states, deterministic=True):
419
+ hidden_states = self.intermediate_dense(hidden_states)
420
+ hidden_states = self.intermediate_act_fn(hidden_states)
421
+ hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic)
422
+
423
+ hidden_states = self.output_dense(hidden_states)
424
+ hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
425
+ return hidden_states
426
+
427
+
428
+ class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
429
+ config: Wav2Vec2Config
430
+ dtype: jnp.dtype = jnp.float32
431
+
432
+ def setup(self):
433
+ self.attention = FlaxWav2Vec2Attention(
434
+ config=self.config,
435
+ embed_dim=self.config.hidden_size,
436
+ num_heads=self.config.num_attention_heads,
437
+ dropout=self.config.attention_dropout,
438
+ dtype=self.dtype,
439
+ )
440
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
441
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
442
+ self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
443
+ self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
444
+
445
+ def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
446
+ if self.config.use_scan:
447
+ hidden_states = hidden_states[0]
448
+ attn_residual = hidden_states
449
+ hidden_states = self.layer_norm(hidden_states)
450
+ hidden_states, attn_weights = self.attention(
451
+ hidden_states, attention_mask=attention_mask, deterministic=deterministic
452
+ )
453
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
454
+ hidden_states = attn_residual + hidden_states
455
+ hidden_states = hidden_states + self.feed_forward(
456
+ self.final_layer_norm(hidden_states), deterministic=deterministic
457
+ )
458
+
459
+ outputs = (hidden_states,)
460
+
461
+ if output_attentions:
462
+ outputs += (attn_weights,)
463
+
464
+ if self.config.use_scan:
465
+ outputs = (outputs, None)
466
+
467
+ return outputs
468
+
469
+
470
+ class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
471
+ config: Wav2Vec2Config
472
+ dtype: jnp.dtype = jnp.float32
473
+
474
+ @nn.compact
475
+ def __call__(
476
+ self,
477
+ hidden_states,
478
+ attention_mask=None,
479
+ deterministic: bool = True,
480
+ output_attentions: bool = False,
481
+ output_hidden_states: bool = False,
482
+ return_dict: bool = True,
483
+ ):
484
+ all_attentions = () if output_attentions else None
485
+ all_hidden_states = () if output_hidden_states else None
486
+
487
+ num_layers = self.config.num_hidden_layers
488
+ BlockEncoderLayer = (
489
+ remat(
490
+ FlaxWav2Vec2EncoderLayerStableLayerNorm,
491
+ static_argnums=(2, 3),
492
+ prevent_cse=not self.config.use_scan,
493
+ )
494
+ if self.config.gradient_checkpointing
495
+ else FlaxWav2Vec2EncoderLayerStableLayerNorm
496
+ )
497
+
498
+ if self.config.use_scan:
499
+ # since all decoder layers are the same, we use nn.scan directly
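+ # scanning compiles the layer body once and stacks all per-layer parameters along a leading axis,
+ # which reduces compilation time; the trade-off is that per-layer outputs are no longer exposed,
+ # hence the asserts on `output_attentions` / `output_hidden_states` below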
500
+ assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
501
+ assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
502
+ hidden_states = (hidden_states,)
503
+
504
+ hidden_states, _ = scan_with_axes(
505
+ BlockEncoderLayer,
506
+ variable_axes={"params": 0, "cache": 0},
507
+ split_rngs={"params": True, "dropout": True},
508
+ in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
509
+ length=num_layers,
510
+ )(self.config, dtype=self.dtype, name="FlaxWav2Vec2EncoderLayers",)(
511
+ hidden_states, attention_mask, deterministic, output_attentions
512
+ )
513
+ hidden_states = hidden_states[0]
514
+
515
+ else:
516
+ for layer in range(num_layers):
517
+ if output_hidden_states:
518
+ all_hidden_states += (hidden_states,)
519
+
520
+ layer_outputs = BlockEncoderLayer(
521
+ self.config,
522
+ dtype=self.dtype,
523
+ name=str(layer),
524
+ )(hidden_states, attention_mask, deterministic, output_attentions)
525
+
526
+ hidden_states = layer_outputs[0]
527
+
528
+ if output_attentions:
529
+ all_attentions += (layer_outputs[1],)
530
+
531
+ if output_hidden_states:
532
+ all_hidden_states += (hidden_states,)
533
+
534
+ outputs = (hidden_states, all_hidden_states, all_attentions)
535
+
536
+ if not return_dict:
537
+ return tuple(v for v in outputs if v is not None)
538
+
539
+ return FlaxBaseModelOutput(
540
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
541
+ )
542
+
543
+
544
+ class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
545
+ config: Wav2Vec2Config
546
+ dtype: jnp.dtype = jnp.float32
547
+
548
+ def setup(self):
549
+ self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
550
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
551
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
552
+ self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)
553
+
554
+ def __call__(
555
+ self,
556
+ hidden_states,
557
+ attention_mask=None,
558
+ deterministic=True,
559
+ output_attentions=False,
560
+ output_hidden_states=False,
561
+ return_dict=True,
562
+ ):
563
+
564
+ if attention_mask is not None:
565
+ # make sure padded tokens are not attended to
566
+ hidden_states = jnp.where(
567
+ jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
568
+ )
569
+
570
+ position_embeddings = self.pos_conv_embed(hidden_states)
571
+
572
+ hidden_states = hidden_states + position_embeddings
573
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
574
+
575
+ outputs = self.layers(
576
+ hidden_states,
577
+ attention_mask,
578
+ output_attentions=output_attentions,
579
+ output_hidden_states=output_hidden_states,
580
+ return_dict=return_dict,
581
+ )
582
+
583
+ last_hidden_state = self.layer_norm(outputs[0])
584
+
585
+ # update the last element in `hidden_states` after applying `layernorm` above
586
+ hidden_states = None
587
+ if output_hidden_states:
588
+ hidden_states = outputs[1]
589
+ hidden_states = hidden_states[:-1] + (last_hidden_state,)
590
+
591
+ if not return_dict:
592
+ outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
593
+ return tuple(v for v in outputs if v is not None)
594
+
595
+ return FlaxBaseModelOutput(
596
+ last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
597
+ )
598
+
599
+
600
+ class FlaxWav2Vec2Adapter(nn.Module):
601
+ config: Wav2Vec2Config
602
+ dtype: jnp.dtype = jnp.float32
603
+
604
+ def setup(self):
605
+ # hidden_states require down-projection if feature dims don't match
606
+ if self.config.output_hidden_size != self.config.hidden_size:
607
+ self.proj = nn.Dense(
608
+ self.config.output_hidden_size,
609
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
610
+ dtype=self.dtype,
611
+ )
612
+ self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
613
+ else:
614
+ self.proj = self.proj_layer_norm = None
615
+
616
+ self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)
617
+
618
+ def __call__(self, hidden_states, deterministic=True):
619
+ # down-project hidden_states if required
620
+ if self.proj is not None and self.proj_layer_norm is not None:
621
+ hidden_states = self.proj(hidden_states)
622
+ hidden_states = self.proj_layer_norm(hidden_states)
623
+
624
+ hidden_states = self.layers(hidden_states)
625
+
626
+ return hidden_states
627
+
628
+
629
+ class FlaxWav2Vec2AdapterLayer(nn.Module):
630
+ config: Wav2Vec2Config
631
+ dtype: jnp.dtype = jnp.float32
632
+
633
+ def setup(self):
634
+ self.conv = nn.Conv(
635
+ features=2 * self.config.output_hidden_size,
636
+ kernel_size=(self.config.adapter_kernel_size,),
637
+ strides=(self.config.adapter_stride,),
638
+ padding=((1, 1),),
639
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
640
+ dtype=self.dtype,
641
+ )
642
+
643
+ def __call__(self, hidden_states):
644
+ hidden_states = self.conv(hidden_states)
645
+ hidden_states = nn.glu(hidden_states, axis=2)
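+ # the convolution above doubles the feature dimension to 2 * output_hidden_size; the gated linear
+ # unit halves it again by gating one half of the features with the sigmoid of the other half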
646
+
647
+ return hidden_states
648
+
649
+
650
+ class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
651
+ config: Wav2Vec2Config
652
+ dtype: jnp.dtype = jnp.float32
653
+
654
+ def setup(self):
655
+ BlockAdapterLayer = remat(FlaxWav2Vec2AdapterLayer) if self.config.gradient_checkpointing else FlaxWav2Vec2AdapterLayer
656
+ self.layers = [
657
+ BlockAdapterLayer(self.config, name=str(i), dtype=self.dtype)
658
+ for i in range(self.config.num_adapter_layers)
659
+ ]
660
+
661
+ def __call__(self, hidden_states):
662
+ for conv_layer in self.layers:
663
+ hidden_states = conv_layer(hidden_states)
664
+
665
+ return hidden_states
666
+
667
+
668
+ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
669
+ """
670
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
671
+ models.
672
+ """
673
+
674
+ config_class = Wav2Vec2Config
675
+ base_model_prefix: str = "wav2vec2"
676
+ main_input_name = "input_values"
677
+ module_class: nn.Module = None
678
+
679
+ def __init__(
680
+ self,
681
+ config: Wav2Vec2Config,
682
+ input_shape: Tuple = (1, 1024),
683
+ seed: int = 0,
684
+ dtype: jnp.dtype = jnp.float32,
685
+ _do_init: bool = True,
686
+ **kwargs,
687
+ ):
688
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
689
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
690
+
691
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
692
+ # init input tensors
693
+ input_values = jnp.zeros(input_shape, dtype="i4")
694
+ attention_mask = jnp.ones_like(input_values)
695
+ params_rng, dropout_rng = jax.random.split(rng, 2)
696
+ rngs = {"params": params_rng, "dropout": dropout_rng}
697
+
698
+ return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
699
+
700
+ def __call__(
701
+ self,
702
+ input_values,
703
+ attention_mask=None,
704
+ mask_time_indices=None,
705
+ extract_features=None,
706
+ params: dict = None,
707
+ dropout_rng: jax.random.PRNGKey = None,
708
+ train: bool = False,
709
+ output_attentions: Optional[bool] = None,
710
+ output_hidden_states: Optional[bool] = None,
711
+ output_features: Optional[bool] = None,
712
+ freeze_feature_encoder: bool = False,
713
+ return_dict: Optional[bool] = None,
714
+ ):
715
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
+ output_hidden_states = (
717
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
718
+ )
719
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
720
+
721
+ if attention_mask is None:
722
+ batch_size, sequence_length = input_values.shape
723
+ attention_mask = jnp.ones((batch_size, sequence_length))
724
+
725
+ if extract_features is not None:
726
+ extract_features = jnp.array(extract_features, dtype="f4")
727
+
728
+ # Handle any PRNG if needed
729
+ rngs = {}
730
+ if dropout_rng is not None:
731
+ rngs["dropout"] = dropout_rng
732
+
733
+ inputs = {"params": params or self.params}
734
+
735
+ return self.module.apply(
736
+ inputs,
737
+ jnp.array(input_values, dtype="f4"),
738
+ jnp.array(attention_mask, dtype="i4"),
739
+ mask_time_indices,
740
+ extract_features,
741
+ not train,
742
+ output_attentions,
743
+ output_hidden_states,
744
+ output_features,
745
+ freeze_feature_encoder,
746
+ return_dict,
747
+ rngs=rngs,
748
+ )
749
+
750
+ def _get_feat_extract_output_lengths(
751
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
752
+ ):
753
+ return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
754
+
755
+ def _get_feature_vector_attention_mask(
756
+ self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
757
+ ):
758
+ return self.module._get_feature_vector_attention_mask(feature_vector_length, attention_mask, add_adapter=add_adapter)
759
+
760
+
761
+ class FlaxWav2Vec2Module(nn.Module):
762
+ config: Wav2Vec2Config
763
+ dtype: jnp.dtype = jnp.float32
764
+
765
+ def setup(self):
766
+ self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
767
+ self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
768
+ self.masked_spec_embed = self.param(
769
+ "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
770
+ )
771
+
772
+ if self.config.do_stable_layer_norm:
773
+ self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
774
+ else:
775
+ raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
776
+
777
+ self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
778
+
779
+ def __call__(
780
+ self,
781
+ input_values,
782
+ attention_mask=None,
783
+ mask_time_indices=None,
784
+ extract_features=None,
785
+ deterministic=True,
786
+ output_attentions=None,
787
+ output_hidden_states=None,
788
+ output_features=False,
789
+ freeze_feature_encoder=False,
790
+ return_dict=None,
791
+ ):
792
+
793
+ # forward pass through the feature extractor if features not specified
794
+ if extract_features is None:
795
+ extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
796
+
797
+ if output_features:
798
+ return extract_features
799
+
800
+ # make sure that no loss is computed on padded inputs
801
+ if attention_mask is not None:
802
+ # compute reduced attention_mask corresponding to feature vectors
803
+ attention_mask = self._get_feature_vector_attention_mask(
804
+ extract_features.shape[1], attention_mask, add_adapter=False
805
+ )
806
+
807
+ hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
808
+ if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
809
+ hidden_states = jnp.where(
810
+ jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
811
+ jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
812
+ hidden_states,
813
+ )
814
+
815
+ encoder_outputs = self.encoder(
816
+ hidden_states,
817
+ attention_mask=attention_mask,
818
+ deterministic=deterministic,
819
+ output_attentions=output_attentions,
820
+ output_hidden_states=output_hidden_states,
821
+ return_dict=return_dict,
822
+ )
823
+
824
+ hidden_states = encoder_outputs[0]
825
+
826
+ if self.adapter is not None:
827
+ hidden_states = self.adapter(hidden_states)
828
+
829
+ if not return_dict:
830
+ return (hidden_states, extract_features) + encoder_outputs[1:]
831
+
832
+ return FlaxWav2Vec2BaseModelOutput(
833
+ last_hidden_state=hidden_states,
834
+ extract_features=extract_features,
835
+ hidden_states=encoder_outputs.hidden_states,
836
+ attentions=encoder_outputs.attentions,
837
+ )
838
+
839
+ def _get_feat_extract_output_lengths(
840
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
841
+ ):
842
+ """
843
+ Computes the output length of the convolutional layers
844
+ """
845
+
846
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
847
+
848
+ def _conv_out_length(input_length, kernel_size, stride):
849
+ # 1D convolutional layer output length formula taken
850
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
851
+ return (input_length - kernel_size) // stride + 1
852
+
853
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
854
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
855
+
856
+ if add_adapter:
857
+ for _ in range(self.config.num_adapter_layers):
858
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
859
+
860
+ return input_lengths
861
+
862
+ def _get_feature_vector_attention_mask(
863
+ self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
864
+ ):
865
+ # Effectively attention_mask.sum(-1), but not in-place so that it can run
866
+ # in inference mode.
867
+ non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
868
+
869
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
870
+
871
+ batch_size = attention_mask.shape[0]
872
+
873
+ attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
874
+ # these two operations make sure that all values
875
+ # before the output lengths indices are attended to
876
+ attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
877
+ attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
878
+ return attention_mask
879
+
880
+
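The flip/cumsum/flip trick above expands a marker placed at the last valid feature frame into a full prefix mask. A tiny illustration (values chosen arbitrarily):

import jax.numpy as jnp

m = jnp.zeros((1, 6)).at[0, 3].set(1)                      # mark index 3 as the last valid frame
prefix_mask = jnp.flip(jnp.flip(m, -1).cumsum(-1), -1).astype("bool")
print(prefix_mask)                                         # [[ True  True  True  True False False]]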
881
+ class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
882
+ module_class = FlaxWav2Vec2Module
883
+
884
+
885
+ class FlaxWav2Vec2ForCTCModule(nn.Module):
886
+ config: Wav2Vec2Config
887
+ dtype: jnp.dtype = jnp.float32
888
+
889
+ def setup(self):
890
+ self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
891
+ self.dropout = nn.Dropout(rate=self.config.final_dropout)
892
+ self.lm_head = nn.Dense(
893
+ self.config.vocab_size,
894
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
895
+ dtype=self.dtype,
896
+ )
897
+
898
+ def __call__(
899
+ self,
900
+ input_values,
901
+ attention_mask=None,
902
+ mask_time_indices=None,
903
+ extract_features=None,
904
+ deterministic=True,
905
+ output_attentions=None,
906
+ output_hidden_states=None,
907
+ output_features=False,
908
+ freeze_feature_encoder=False,
909
+ return_dict=None,
910
+ ):
911
+ outputs = self.wav2vec2(
912
+ input_values,
913
+ attention_mask=attention_mask,
914
+ mask_time_indices=mask_time_indices,
915
+ deterministic=deterministic,
916
+ output_attentions=output_attentions,
917
+ output_hidden_states=output_hidden_states,
918
+ freeze_feature_encoder=freeze_feature_encoder,
919
+ return_dict=return_dict,
920
+ )
921
+
922
+ hidden_states = outputs[0]
923
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
924
+
925
+ logits = self.lm_head(hidden_states)
926
+
927
+ if not return_dict:
928
+ return (logits,) + outputs[2:]
929
+
930
+ return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
931
+
932
+ def _get_feat_extract_output_lengths(
933
+ self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
934
+ ):
935
+ """
936
+ Computes the output length of the convolutional layers
937
+ """
938
+
939
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
940
+
941
+ def _conv_out_length(input_length, kernel_size, stride):
942
+ # 1D convolutional layer output length formula taken
943
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
944
+ return (input_length - kernel_size) // stride + 1
945
+
946
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
947
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
948
+
949
+ if add_adapter:
950
+ for _ in range(self.config.num_adapter_layers):
951
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
952
+
953
+ return input_lengths
954
+
955
+ def _get_feature_vector_attention_mask(
956
+ self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
957
+ ):
958
+ # Effectively attention_mask.sum(-1), but not in-place so that it can run
959
+ # in inference mode.
960
+ non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
961
+
962
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
963
+
964
+ batch_size = attention_mask.shape[0]
965
+
966
+ attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
967
+ # these two operations make sure that all values
968
+ # before the output lengths indices are attended to
969
+ attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
970
+ attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
971
+ return attention_mask
972
+
973
+
974
+ class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
975
+ module_class = FlaxWav2Vec2ForCTCModule
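For a rough sense of scale, assuming the standard wav2vec 2.0 / XLS-R feature-encoder defaults conv_kernel=(10, 3, 3, 3, 3, 2, 2) and conv_stride=(5, 2, 2, 2, 2, 2, 2), the output-length formula used in `_get_feat_extract_output_lengths` above maps one second of 16 kHz audio to about 49 encoder frames:

def conv_out_length(length, kernel, stride):
    # mirrors the 1D convolution output-length formula used above
    return (length - kernel) // stride + 1

length = 16000                                             # 1 s of audio at 16 kHz
for kernel, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
    length = conv_out_length(length, kernel, stride)
print(length)                                              # 49 frames, i.e. roughly 20 ms per frame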
preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0,
7
+ "return_attention_mask": true,
8
+ "sampling_rate": 16000
9
+ }
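A minimal sketch of how this file is consumed (the audio array is a placeholder, and "./" assumes the repository root as the working directory):

import numpy as np
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./")   # reads preprocessor_config.json
audio = np.zeros(16000, dtype=np.float32)                            # 1 s of placeholder 16 kHz audio
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="np")
print(inputs.input_values.shape, inputs.attention_mask.shape)        # (1, 16000) (1, 16000)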
run.sh ADDED
@@ -0,0 +1,48 @@
1
+ WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_flax_speech_recognition_ctc.py \
2
+ --model_name_or_path="facebook/wav2vec2-xls-r-1b" \
3
+ --hub_model_id="NbAiLab/wav2vec2-1b-npsc-nst" \
4
+ --tokenizer_name="./" \
5
+ --output_dir="./" \
6
+ --overwrite_output_dir \
7
+ --num_train_epochs="40" \
8
+ --per_device_train_batch_size="12" \
9
+ --per_device_eval_batch_size="12" \
10
+ --gradient_accumulation_steps="1" \
11
+ --precision="full_mixed" \
12
+ --matmul_precision="bfloat16" \
13
+ --learning_rate="1e-4" \
14
+ --warmup_steps="4000" \
15
+ --length_column_name="input_length" \
16
+ --evaluation_strategy="steps" \
17
+ --text_column_name="text" \
18
+ --save_steps="1000" \
19
+ --eval_steps="1000" \
20
+ --logging_steps="100" \
21
+ --layerdrop="0.041" \
22
+ --attention_dropout="0.094" \
23
+ --activation_dropout="0.055" \
24
+ --hidden_dropout="0.047" \
25
+ --save_total_limit="5" \
26
+ --freeze_feature_encoder \
27
+ --feat_proj_dropout="0.04" \
28
+ --mask_time_prob="0.082" \
29
+ --mask_time_length="10" \
30
+ --mask_feature_prob="0.25" \
31
+ --mask_feature_length="64" \
32
+ --gradient_checkpointing \
33
+ --min_duration_in_seconds="0.5" \
34
+ --max_duration_in_seconds="20.0" \
35
+ --use_auth_token \
36
+ --seed="42" \
37
+ --group_by_length \
38
+ --do_train --do_eval \
39
+ --push_to_hub \
40
+ --preprocessing_num_workers="32" \
41
+ --ctc_zero_infinity \
42
+ --do_lower_case \
43
+ --wandb_project="wav2vec2" \
44
+ --wandb_name="wav2vec2-1b-npsc-nst" \
45
+ --remove_punctuation
46
+
47
+
48
+ # --fp16
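For orientation, assuming a single TPU v3-8 host (8 local devices; the device count is not fixed by the flags themselves), the batch-size flags above work out as follows:

per_device_train_batch_size = 12
num_devices = 8                       # assumption: TPU v3-8
gradient_accumulation_steps = 1
effective_batch_size = per_device_train_batch_size * num_devices * gradient_accumulation_steps
print(effective_batch_size)           # 96 examples per optimizer step
# learning rate: linear warmup to 1e-4 over the first 4000 steps, then linear decay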
run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1604 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import unreplicate, pad_shard_unpad
44
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.utils import check_min_version
59
+ from transformers.utils.versions import require_version
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ check_min_version("4.17.0.dev0")
64
+
65
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ @flax.struct.dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ feature_extractor_name: Optional[str] = field(
86
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
91
+ )
92
+ use_fast_tokenizer: bool = field(
93
+ default=True,
94
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
95
+ )
96
+ model_revision: str = field(
97
+ default="main",
98
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
99
+ )
100
+ use_auth_token: bool = field(
101
+ default=False,
102
+ metadata={
103
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
104
+ "with private models)."
105
+ },
106
+ )
107
+ freeze_feature_encoder: bool = field(
108
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
109
+ )
110
+ attention_dropout: float = field(
111
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
112
+ )
113
+ activation_dropout: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
117
+ },
118
+ )
119
+ hidden_dropout: float = field(
120
+ default=0.1,
121
+ metadata={
122
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
123
+ },
124
+ )
125
+ feat_proj_dropout: float = field(
126
+ default=0.0,
127
+ metadata={
128
+ "help": "The feat proj dropout probability for feature encoder representations."
129
+ },
130
+ )
131
+ final_dropout: float = field(
132
+ default=0.0,
133
+ metadata={"help": "The dropout probability for the final projection layer."},
134
+ )
135
+ mask_time_prob: float = field(
136
+ default=0.1,
137
+ metadata={
138
+ "help": "The spec aug dropout probability for feature encoder representations."
139
+ },
140
+ )
141
+ mask_time_length: int = field(
142
+ default=10,
143
+ metadata={"help": "Length of vector span to mask along the time axis."},
144
+ )
145
+ mask_feature_prob: float = field(
146
+ default=0.0,
147
+ metadata={
148
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
149
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
150
+ },
151
+ )
152
+ mask_feature_length: int = field(
153
+ default=10,
154
+ metadata={"help": "Length of vector span to mask along the feature axis."},
155
+ )
156
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
157
+ ctc_loss_reduction: Optional[str] = field(
158
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
159
+ )
160
+ ctc_zero_infinity: Optional[bool] = field(
161
+ default=False, metadata={"help": "If True, will try to avoid the CTC loss going to infinity."}
162
+ )
163
+
164
+
165
+ @flax.struct.dataclass
166
+ class DataTrainingArguments:
167
+ """
168
+ Arguments pertaining to what data we are going to input our model for training and eval.
169
+ """
170
+
171
+ dataset_name: str = field(
172
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
173
+ )
174
+ dataset_config_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
176
+ )
177
+ text_column: Optional[str] = field(
178
+ default=None,
179
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
180
+ )
181
+ dataset_cache_dir: Optional[str] = field(
182
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
183
+ )
184
+ overwrite_cache: bool = field(
185
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
186
+ )
187
+ preprocessing_num_workers: Optional[int] = field(
188
+ default=None,
189
+ metadata={"help": "The number of processes to use for the preprocessing."},
190
+ )
191
+ max_train_samples: Optional[int] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ },
197
+ )
198
+ max_eval_samples: Optional[int] = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
202
+ "value if set."
203
+ },
204
+ )
205
+ max_test_samples: Optional[int] = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
209
+ "value if set."
210
+ },
211
+ )
212
+ audio_column_name: str = field(
213
+ default="audio",
214
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
215
+ )
216
+ text_column_name: str = field(
217
+ default="text",
218
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
219
+ )
220
+ max_duration_in_seconds: float = field(
221
+ default=20.0,
222
+ metadata={
223
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
224
+ },
225
+ )
226
+ min_duration_in_seconds: float = field(
227
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
228
+ )
229
+ max_label_length: Optional[int] = field(
230
+ default=512,
231
+ metadata={
232
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
233
+ "than this will be filtered."
234
+ },
235
+ )
236
+ min_label_length: Optional[int] = field(
237
+ default=2,
238
+ metadata={
239
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
240
+ "than this will be filtered."
241
+ },
242
+ )
243
+ pad_input_to_multiple_of: Optional[int] = field(
244
+ default=32000,
245
+ metadata={
246
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
247
+ "This is important to avoid triggering recompilations on TPU."
248
+ },
249
+ )
250
+ pad_target_to_multiple_of: Optional[int] = field(
251
+ default=None,
252
+ metadata={
253
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
254
+ "This is important to avoid triggering recompilations on TPU."
255
+ },
256
+ )
257
+ preprocessing_only: bool = field(
258
+ default=False,
259
+ metadata={
260
+ "help": "Whether to only do data preprocessing and skip training. "
261
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
262
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
263
+ "so that the cached datasets can consequently be loaded in distributed training"
264
+ },
265
+ )
266
+ train_split_name: str = field(
267
+ default="train",
268
+ metadata={
269
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
270
+ },
271
+ )
272
+ eval_split_name: str = field(
273
+ default="validation",
274
+ metadata={
275
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
276
+ },
277
+ )
278
+ do_lower_case: bool = field(
279
+ default=True,
280
+ metadata={"help": "Whether the target text should be lower cased."},
281
+ )
282
+ wandb_project: str = field(
283
+ default="flax-speech-recognition-ctc",
284
+ metadata={"help": "The name of the wandb project."},
285
+ )
286
+ wandb_name: str = field(
287
+ default=None,
288
+ metadata={"help": "The name of the wandb run."},
289
+ )
290
+ wandb_job_type: str = field(
291
+ default="CTC",
292
+ metadata={"help": "The name of the wandb job type."},
293
+ )
294
+ test_split_name: str = field(
295
+ default="test",
296
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
297
+ )
298
+ remove_punctuation: bool = field(
299
+ default=False, metadata={"help": "Whether or not to remove punctuation during training."}
300
+ )
301
+
302
+
303
+ # @flax.struct.dataclass
304
+ @dataclass
305
+ class FlaxTrainingArguments(TrainingArguments):
306
+ precision: str = field(
307
+ default="full",
308
+ metadata={
309
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
310
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
311
+ },
312
+ )
313
+ matmul_precision: str = field(
314
+ default="default",
315
+ metadata={
316
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
317
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
318
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
319
+ "it only changes the behaviors of calls with no such argument provided. "
320
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
321
+ },
322
+ )
323
+ multisteps: bool = field(
324
+ default=False,
325
+ metadata={
326
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
327
+ "a custom gradient accumulation implementation will be employed."
328
+ },
329
+ )
330
+
331
+
332
+ def to_fp32(t):
333
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
334
+
335
+
336
+ def to_bf16(t):
337
+ return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
338
+
339
+
340
+ class MixedPrecisionTrainState(struct.PyTreeNode):
341
+ """Train state for use with a single Optax optimizer.
342
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
343
+
344
+ Synopsis::
345
+
346
+ state = TrainState.create(
347
+ apply_fn=model.apply,
348
+ params=variables['params'],
349
+ tx=tx)
350
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
351
+ for batch in data:
352
+ grads = grad_fn(state.params, batch)
353
+ state = state.apply_gradients(grads=grads)
354
+
355
+ Args:
356
+ step: Counter starts at 0 and is incremented by every call to
357
+ `.apply_gradients()`.
358
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
359
+ convenience to have a shorter params list for the `train_step()` function
360
+ in your training loop.
361
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
362
+ tx: An Optax gradient transformation.
363
+ opt_state: The state for `tx`.
364
+ dropout_rng: PRNG key for stochastic operations.
365
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
366
+ """
367
+
368
+ step: int
369
+ apply_fn: Callable = struct.field(pytree_node=False)
370
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
371
+ params: core.FrozenDict[str, Any]
372
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
373
+ opt_state: optax.OptState
374
+ dropout_rng: jnp.ndarray
375
+ max_grad_norm: Optional[float] = 1.0
376
+
377
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
378
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
379
+
380
+ Note that internally this function calls `.tx.update()` followed by a call
381
+ to `optax.apply_updates()` to update `params` and `opt_state`.
382
+
383
+ Args:
384
+ grads: Gradients that have the same pytree structure as `.params`.
385
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
386
+
387
+ Returns:
388
+ An updated instance of `self` with `step` incremented by one, `params`
389
+ and `opt_state` updated by applying `grads`, and additional attributes
390
+ replaced as specified by `kwargs`.
391
+ """
392
+
393
+ # clip gradients by global l2 norm
394
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
395
+ g_norm = linear_algebra.global_norm(grads)
396
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
397
+ grads = jax.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
398
+
399
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
400
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
401
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
402
+
403
+ new_params = optax.apply_updates(self.params, updates)
404
+ return self.replace(
405
+ step=self.step + 1,
406
+ params=new_params,
407
+ opt_state=to_dtype(new_opt_state),
408
+ **kwargs,
409
+ )
410
+
411
+ @classmethod
412
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
413
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
414
+ # downcast optimizer state to bf16 if mixed-precision training
415
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
416
+ return cls(
417
+ step=0,
418
+ apply_fn=apply_fn,
419
+ params=params,
420
+ tx=tx,
421
+ opt_state=opt_state,
422
+ **kwargs,
423
+ )
424
+
425
+ def replicate(self):
426
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
427
+
428
+
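A rough sketch of how such a state might be constructed (`model` and `rng` are placeholders, and the AdamW optimizer here is illustrative rather than the exact one the script builds):

import optax

tx = optax.adamw(learning_rate=1e-4)
state = MixedPrecisionTrainState.create(
    apply_fn=model.__call__,
    get_attention_mask_fn=model._get_feature_vector_attention_mask,
    params=model.params,
    tx=tx,
    to_dtype=to_bf16,                 # keep the optimizer state in bfloat16 for mixed precision
    dropout_rng=rng,
    max_grad_norm=1.0,
)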
429
+ @flax.struct.dataclass
430
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
431
+ """
432
+ Data collator that will dynamically pad the inputs received.
433
+ Args:
434
+ processor ([`Wav2Vec2Processor`])
435
+ The processor used for proccessing the data.
436
+ decoder_start_token_id (:obj: `int`)
437
+ The begin-of-sentence of the decoder.
438
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
439
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
440
+ among:
441
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
442
+ sequence is provided).
443
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
444
+ maximum acceptable input length for the model if that argument is not provided.
445
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
446
+ different lengths).
447
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
448
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
449
+ See above for details.
450
+ max_input_length (:obj:`float`, `optional`):
451
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
452
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
453
+ If set will pad the input sequence to a multiple of the provided value.
454
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
455
+ 7.5 (Volta).
456
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
457
+ If set will pad the target sequence to a multiple of the provided value.
458
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
459
+ 7.5 (Volta).
460
+ """
461
+
462
+ processor: Any
463
+ input_padding: Union[bool, str] = "longest"
464
+ label_padding: Union[bool, str] = "max_length"
465
+ pad_input_to_multiple_of: Optional[int] = None
466
+ pad_to_multiple_of_label: Optional[int] = None
467
+ max_input_length: Optional[float] = None
468
+ max_label_length: Optional[float] = None
469
+
470
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
471
+ # split inputs and labels since they have to be of different lengths and need
472
+ # different padding methods
473
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
474
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
475
+
476
+ # reformat list to dict and set to pytorch format
477
+ batch = self.processor.feature_extractor.pad(
478
+ input_features,
479
+ max_length=self.max_input_length,
480
+ padding=self.input_padding,
481
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
482
+ return_tensors="np",
483
+ )
484
+
485
+ labels_batch = self.processor.tokenizer.pad(
486
+ label_features,
487
+ max_length=self.max_label_length,
488
+ padding=self.label_padding,
489
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
490
+ return_tensors="np",
491
+ )
492
+
493
+ labels = labels_batch["input_ids"]
494
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
495
+ labels = labels.filled(fill_value=-100)
496
+
497
+ batch["labels"] = labels
498
+
499
+ return batch
500
+
501
+
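The masked-array step above is what turns padded label positions into -100 so they are ignored by the CTC loss. A small illustration with made-up ids:

import numpy as np

labels = np.array([[5, 9, 2, 0], [7, 3, 0, 0]])            # 0 = (assumed) pad token id
attention_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])
masked = np.ma.array(labels, mask=np.not_equal(attention_mask, 1))
print(masked.filled(fill_value=-100))
# [[   5    9    2 -100]
#  [   7    3 -100 -100]]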
502
+ def get_grouped_indices(
503
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
504
+ ) -> np.array:
505
+ """
506
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
507
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
508
+ lengths. To do this, the indices are:
509
+
510
+ - randomly permuted (if a JAX rng is specified)
511
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
512
+ - sorted by length in each mega-batch
513
+
514
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
515
+ maximum length placed first, so that an OOM happens sooner rather than later.
516
+ """
517
+ lengths = dataset["input_length"]
518
+
519
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
520
+ if mega_batch_mult is None:
521
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
522
+ # Just in case, for tiny datasets
523
+ if mega_batch_mult == 0:
524
+ mega_batch_mult = 1
525
+
526
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
527
+ num_samples = len(lengths)
528
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
529
+
530
+ megabatch_size = mega_batch_mult * batch_size
531
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
532
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
533
+
534
+ # The rest is to get the biggest batch first.
535
+ # Since each megabatch is sorted by descending length, the longest element is the first
536
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
537
+ max_idx = np.argmax(megabatch_maximums).item()
538
+ # Switch to put the longest batch in first position
539
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
540
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
541
+
542
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
543
+
544
+ return megabatches
545
+
546
+
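A toy run of the length-grouping above (a plain dict stands in for the dataset, since only the "input_length" column is read):

toy = {"input_length": [3, 50, 7, 48, 5, 49, 9, 51]}
print(get_grouped_indices(toy, batch_size=2, rng=None, mega_batch_mult=2))
# [7 5 6 4 1 3 2 0] -> consecutive pairs (7,5), (6,4), (1,3), (2,0) have similar lengths,
# and the mega-batch containing the longest utterance (index 7) comes first.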
547
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
548
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
549
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
550
+ num_samples = len(samples_idx)
551
+ if drop_last:
552
+ samples_to_remove = num_samples % batch_size
553
+ if samples_to_remove != 0:
554
+ samples_idx = samples_idx[:-samples_to_remove]
555
+ sections_split = num_samples // batch_size
556
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
557
+ else:
558
+ sections_split = math.ceil(num_samples / batch_size)
559
+ samples_idx = np.array_split(samples_idx, sections_split)
560
+ return samples_idx
561
+
562
+
563
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
564
+ summary_writer.scalar("train_time", train_time, step)
565
+
566
+ train_metrics = get_metrics(train_metrics)
567
+ for key, vals in train_metrics.items():
568
+ tag = f"train_{key}"
569
+ for i, val in enumerate(vals):
570
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
571
+
572
+
573
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
574
+ for metric_name, value in eval_metrics.items():
575
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
576
+
577
+ if pred_str is not None:
578
+ # write output actual predictions for debugging
579
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
580
+
581
+
582
+ def write_wandb_log(metrics, step, prefix=None):
583
+ if jax.process_index() == 0:
584
+ log_metrics = {}
585
+ for k, v in metrics.items():
586
+ if "layer" in k:
587
+ log_metrics[f"{k}/"] = v
588
+ elif prefix is not None:
589
+ log_metrics[f"{prefix}/{k}"] = v
590
+ else:
591
+ log_metrics[k] = v
592
+ wandb.log(log_metrics, step)
593
+
594
+
595
+ def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"):
596
+ if jax.process_index() == 0:
597
+ # convert str data to a wandb compatible format
598
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
599
+ # we'll log the first 50 predictions for each epoch
600
+ wandb.log(
601
+ {
602
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
603
+ columns=["label_str", "pred_str"], data=str_data[:num_log]
604
+ )
605
+ },
606
+ step,
607
+ )
608
+
609
+
610
+ def create_learning_rate_fn(
611
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
612
+ ) -> Callable[[int], jnp.array]:
613
+ """Returns a linear warmup, linear_decay learning rate function."""
614
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
615
+ decay_fn = optax.linear_schedule(
616
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
617
+ )
618
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
619
+ return schedule_fn
620
+
621
+
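Evaluating the schedule at a few points makes its shape concrete (the step counts here are arbitrary):

lr_fn = create_learning_rate_fn(num_train_steps=10_000, num_warmup_steps=1_000, learning_rate=1e-4)
print(float(lr_fn(0)))        # 0.0        (start of warmup)
print(float(lr_fn(500)))      # 5e-05      (halfway through warmup)
print(float(lr_fn(1_000)))    # 0.0001     (peak learning rate)
print(float(lr_fn(10_000)))   # 0.0        (fully decayed)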
622
+ def ctc_loss(
623
+ logits,
624
+ logits_attention_mask,
625
+ labels,
626
+ blank_id,
627
+ loss_reduction="mean",
628
+ output_emission_dict=False,
629
+ log_epsilon=-100000.0,
630
+ ):
631
+ """Computes CTC loss.
632
+ This function performs forward computation over an FSA with `N * 2` states
633
+ where `N` is the max number of labels. The states are split into two groups:
634
+ Phi states and emission states. A phi-state accepts repetition of
635
+ phi (blank)-symbols and transits to emission state when the correct label is
636
+ observed. An emission state accepts repetition of the label and transits to
637
+ the next phi states at any time (so called epsilon-transition).
638
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
639
+ and `N` denotes the time steps in `labels`.
640
+ Args:
641
+ logits: (B, T, K)-array containing log-probabilities of each class.
642
+ logits_attention_mask: (B, T)-array. Attention mask for `logits` (the inverse of the
643
+ padding indicators): padded frames are marked with 0.
644
+ labels: (B, N)-array containing reference integer labels. Padded label positions are
645
+ marked with -100, and `labels` must be right-padded, i.e. each row consists of
646
+ real labels followed by padding.
647
+ blank_id: Id for blank token.
648
+ loss_reduction: one of "mean", "sum", "none"
649
+ - "none": no reduction is applied.
650
+ - "mean": output loss will be divided by target lengths and then the
651
+ mean over the batch is taken.
652
+ - "sum": output loss are summed over batch
653
+ output_emission_dict: whether to output additional information about the emission probs
654
+ Returns:
655
+ A pair of `(per_seq_loss, aux)`.
656
+ per_seq_loss:
657
+ (B,)-array containing loss values for each sequence in the batch.
658
+ aux: Dictionary containing interim variables used for computing losses.
659
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
660
+ phi-state corresponding to the n-th label.
661
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
662
+ emission-state corresponding to the n-th label.
663
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
664
+ corresponding to each time frame.
665
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
666
+ corresponding to each time frame.
667
+ """
668
+ # label paddings are indicated by -100
669
+ labelpaddings = labels < 0
670
+ # logit paddings are the inverse of attention_mask
671
+ logitpaddings = ~logits_attention_mask
672
+
673
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
674
+ batchsize, unused_maxinputlen, num_classes = logits.shape
675
+ batchsize_, maxlabellen = labels.shape
676
+
677
+ logprobs = jax.nn.log_softmax(logits)
678
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
679
+
680
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
681
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
682
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
683
+
684
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
685
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
686
+
687
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
688
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
689
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
690
+
691
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
692
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
693
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
694
+
695
+ def loop_body(prev, x):
696
+ prev_phi, prev_emit = prev
697
+ # emit-to-phi epsilon transition, except if the next label is repetition
698
+ prev_phi_orig = prev_phi
699
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
700
+
701
+ logprob_emit, logprob_phi, pad = x
702
+
703
+ # phi-to-emit transition
704
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
705
+ # self-loop transition
706
+ next_phi = prev_phi + logprob_phi
707
+ # emit-to-phi blank transition only when the next label is repetition
708
+ next_phi = next_phi.at[:, 1:].set(
709
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
710
+ )
711
+
712
+ pad = pad.reshape((batchsize, 1))
713
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
714
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
715
+
716
+ return (next_phi, next_emit), (next_phi, next_emit)
717
+
718
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
719
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
720
+
721
+ # last row needs to be updated with the last epsilon transition
722
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
723
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
724
+
725
+ # extract per_seq_loss
726
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
727
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
728
+
729
+ if loss_reduction == "mean":
730
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
731
+ loss = (per_seq_loss / target_lengths).mean()
732
+ elif loss_reduction == "sum":
733
+ loss = per_seq_loss.sum()
734
+ else:
735
+ loss = per_seq_loss
736
+
737
+ if not output_emission_dict:
738
+ return loss
739
+
740
+ return loss, {
741
+ "logalpha_phi": logalpha_phi,
742
+ "logalpha_emit": logalpha_emit,
743
+ "logprobs_phi": logprobs_phi,
744
+ "logprobs_emit": logprobs_emit,
745
+ }
746
+
747
+
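A toy call showing the calling convention (uniform logits, a boolean attention mask, and label padding marked with -100; the shapes are arbitrary):

import jax.numpy as jnp

B, T, K = 2, 12, 5                                         # batch, time steps, vocab size incl. blank
logits = jnp.zeros((B, T, K))                              # uniform class scores
attention_mask = jnp.ones((B, T), dtype=bool)
labels = jnp.array([[1, 2, 3, -100], [2, 2, 4, 1]])        # -100 marks label padding
loss = ctc_loss(logits, attention_mask, labels, blank_id=0, loss_reduction="mean")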
748
+ def make_dataset(data_args, seed=42):
749
+ # Pre-processing dataset
750
+ import re
751
+
752
+ def map_nst(entry):
753
+ text = entry["text"].lower()
754
+ text = text.replace("(...vær stille under dette opptaket...)", "")
755
+ text = re.sub('[áàâ]', 'a', text)
756
+ text = re.sub('[ä]', 'æ', text)
757
+ text = re.sub('[éèëê]', 'e', text)
758
+ text = re.sub('[íìïî]', 'i', text)
759
+ text = re.sub('[óòöô]', 'o', text)
760
+ text = re.sub('[ö]', 'ø', text)
761
+ text = re.sub('[ç]', 'c', text)
762
+ text = re.sub('[úùüû]', 'u', text)
763
+ # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
764
+ text = re.sub('\s+', ' ', text)
765
+ return {"text": text}
766
+
767
+ def filter_nst(entry):
768
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
769
+ return False # Too short
770
+ if re.match("pIW|CA", entry["type"]):  # re.match(pattern, string)
771
+ return False # Spelling out words
772
+ return True
773
+
774
+ def filter_npsc(entry):
775
+ # False if there are digits in the text
776
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
777
+ return False # Too short
778
+ if re.search("\d", entry["text"]):
779
+ return False
780
+ return True
781
+
782
+ def map_npsc(entry):
783
+ batch = {"text": entry["text"].lower()}
784
+ batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
785
+ batch["text"] = re.sub('[ä]', 'æ', batch["text"])
786
+ batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
787
+ batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
788
+ batch["text"] = re.sub('[óòöô]', 'o', batch["text"])
789
+ batch["text"] = re.sub('[ö]', 'ø', batch["text"])
790
+ batch["text"] = re.sub('[ç]', 'c', batch["text"])
791
+ batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
792
+ batch["text"] = re.sub('\s', ' ', batch["text"])
793
+ batch["text"] = re.sub('<ee>', 'eee', batch["text"])
794
+ batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
795
+ batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
796
+ batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
797
+ # batch["text"] = re.sub('<inaudible>', '?', batch["text"])
798
+ if "<" in batch["text"]:
799
+ raise ValueError(batch["text"])
800
+ return batch
801
+
802
+ nst = datasets.load_dataset("NbAiLab/NST", "no-close")
803
+ npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
804
+ # TODO NST_hesitate
805
+
806
+ split = len(npsc[data_args.train_split_name]) / (len(npsc[data_args.train_split_name]) + len(npsc[data_args.eval_split_name])) # Use same train/val ratio as NPSC
807
+ nst_train = nst[data_args.train_split_name].train_test_split(train_size=split, seed=seed)
808
+ nst[data_args.train_split_name] = nst_train["train"]
809
+ nst[data_args.eval_split_name] = nst_train["test"]
810
+
811
+ nst = nst.filter(filter_nst).map(
812
+ map_nst,
813
+ num_proc=data_args.preprocessing_num_workers,
814
+ desc="filtering NST",
815
+ ).shuffle(seed=seed)
816
+ npsc = npsc.filter(filter_npsc).map(
817
+ map_npsc,
818
+ num_proc=data_args.preprocessing_num_workers,
819
+ desc="filtering NPSC",
820
+ ).shuffle(seed=seed)
821
+
822
+ npsc_base = npsc.remove_columns([col for col in npsc[data_args.train_split_name].column_names if col not in ["text", "audio"]])
823
+ nst_base = nst.remove_columns([col for col in nst[data_args.train_split_name].column_names if col not in ["text", "audio"]])
824
+
825
+ combined = {}
826
+ for split in data_args.train_split_name, data_args.eval_split_name, data_args.test_split_name:
827
+ probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples
828
+ probs = (probs / probs.sum()).tolist()
829
+ comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
830
+ combined[split] = comb
831
+
832
+ return datasets.DatasetDict(**combined)
833
+
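A quick illustration of the accent folding and whitespace normalization performed by `map_nst`/`map_npsc` (the sample sentence is invented):

import re

text = "Dét blir  kjørt  på månen".lower()
text = re.sub('[éèëê]', 'e', text)
text = re.sub(r'\s+', ' ', text)
print(text)                                                # "det blir kjørt på månen"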
834
+ def main():
835
+ # 1. Parse input arguments
836
+ # See all possible arguments in src/transformers/training_args.py
837
+ # or by passing the --help flag to this script.
838
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
839
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
840
+
841
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
842
+ # If we pass only one argument to the script and it's the path to a json file,
843
+ # let's parse it to get our arguments.
844
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
845
+ else:
846
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
847
+
848
+ # 2. Setup logging
849
+ # Make one log on every process with the configuration for debugging.
850
+ logging.basicConfig(
851
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
852
+ datefmt="%m/%d/%Y %H:%M:%S",
853
+ handlers=[logging.StreamHandler(sys.stdout)],
854
+ )
855
+ # Set the verbosity to info of the Transformers logger.
856
+ # We only want one process per machine to log things on the screen.
857
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
858
+ if jax.process_index() == 0:
859
+ datasets.utils.logging.set_verbosity_warning()
860
+ transformers.utils.logging.set_verbosity_info()
861
+ else:
862
+ datasets.utils.logging.set_verbosity_error()
863
+ transformers.utils.logging.set_verbosity_error()
864
+
865
+ # Set up wandb run
866
+ if jax.process_index() == 0:
867
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
868
+
869
+ logger.info("Training/evaluation parameters %s", training_args)
870
+
871
+ # Set the default TPU matmul precision and display the number of devices
872
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
873
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
874
+
875
+ # 4. Load dataset
876
+
877
+ set_seed(training_args.seed)
878
+ raw_datasets = make_dataset(data_args, seed=training_args.seed)
879
+
880
+ # raw_datasets = DatasetDict()
881
+
882
+ # if training_args.do_train:
883
+ # raw_datasets[data_args.train_split_name] = load_dataset(
884
+ # data_args.dataset_name,
885
+ # data_args.dataset_config_name,
886
+ # split=data_args.train_split_name,
887
+ # cache_dir=data_args.dataset_cache_dir,
888
+ # use_auth_token=True if model_args.use_auth_token else None,
889
+ # )
890
+
891
+ # if training_args.do_eval:
892
+ # raw_datasets[data_args.eval_split_name] = load_dataset(
893
+ # data_args.dataset_name,
894
+ # data_args.dataset_config_name,
895
+ # split=data_args.eval_split_name,
896
+ # cache_dir=data_args.dataset_cache_dir,
897
+ # use_auth_token=True if model_args.use_auth_token else None,
898
+ # )
899
+
900
+ # if training_args.do_predict:
901
+ # test_split = data_args.test_split_name.split("+")
902
+ # for split in test_split:
903
+ # raw_datasets[split] = load_dataset(
904
+ # data_args.dataset_name,
905
+ # data_args.dataset_config_name,
906
+ # split=split,
907
+ # cache_dir=data_args.dataset_cache_dir,
908
+ # use_auth_token=True if model_args.use_auth_token else None,
909
+ # )
910
+
911
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
912
+ raise ValueError(
913
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
914
+ "training, evaluation or prediction has to be done."
915
+ )
916
+
917
+ # if not training, there is no need to run multiple epochs
918
+ if not training_args.do_train:
919
+ training_args.num_train_epochs = 1
920
+
921
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
922
+ raise ValueError(
923
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
924
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
925
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
926
+ )
927
+
928
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
929
+ raise ValueError(
930
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
931
+ "Make sure to set `--text_column_name` to the correct text column - one of "
932
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
933
+ )
934
+
935
+ # 5. Load pretrained model, tokenizer, and feature extractor
936
+ #
937
+ # Distributed training:
938
+ # The .from_pretrained methods guarantee that only one local process can concurrently
939
+ config = Wav2Vec2Config.from_pretrained(
940
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
941
+ cache_dir=model_args.cache_dir,
942
+ revision=model_args.model_revision,
943
+ use_auth_token=True if model_args.use_auth_token else None,
944
+ )
945
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
946
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
947
+ cache_dir=model_args.cache_dir,
948
+ revision=model_args.model_revision,
949
+ use_auth_token=True if model_args.use_auth_token else None,
950
+ )
951
+ tokenizer = AutoTokenizer.from_pretrained(
952
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
953
+ cache_dir=model_args.cache_dir,
954
+ revision=model_args.model_revision,
955
+ use_auth_token=True if model_args.use_auth_token else None,
956
+ )
957
+ # update config according to training args, model args, and tokenizer attributes
958
+ config.update(
959
+ {
960
+ "feat_proj_dropout": model_args.feat_proj_dropout,
961
+ "attention_dropout": model_args.attention_dropout,
962
+ "hidden_dropout": model_args.hidden_dropout,
963
+ "final_dropout": model_args.final_dropout,
964
+ "mask_time_prob": model_args.mask_time_prob,
965
+ "mask_time_length": model_args.mask_time_length,
966
+ "mask_feature_prob": model_args.mask_feature_prob,
967
+ "mask_feature_length": model_args.mask_feature_length,
968
+ "gradient_checkpointing": training_args.gradient_checkpointing,
969
+ "layerdrop": model_args.layerdrop,
970
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
971
+ "ctc_zero_infinity": model_args.ctc_zero_infinity,
972
+ "pad_token_id": tokenizer.pad_token_id,
973
+ "vocab_size": tokenizer.vocab_size, # len(tokenizer),
974
+ "activation_dropout": model_args.activation_dropout,
975
+ }
976
+ )
977
+
978
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
979
+ raise ValueError(
980
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
981
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus,"
982
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is mostly likely "
983
+ "built on an lowercased corpus. In this case, set `tokenizer.do_lower_case` to ``False`."
984
+ )
985
+
986
+ if training_args.precision == "full_mixed":
987
+ dtype = jnp.bfloat16
988
+ training_args.mixed_precision = True
989
+ elif training_args.precision == "half_mixed":
990
+ dtype = jnp.bfloat16
991
+ training_args.mixed_precision = False
992
+ else:
993
+ dtype = jnp.float32
994
+ training_args.mixed_precision = False
995
+
996
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
997
+ model_args.model_name_or_path,
998
+ config=config,
999
+ dtype=dtype,
1000
+ cache_dir=model_args.cache_dir,
1001
+ revision=model_args.model_revision,
1002
+ use_auth_token=True if model_args.use_auth_token else None,
1003
+ from_pt=True,
1004
+ )
1005
+
1006
+ # 6. Resample speech dataset ALWAYS
1007
+ raw_datasets = raw_datasets.cast_column(
1008
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
1009
+ )
1010
+
1011
+ # 7. Preprocessing the datasets.
1012
+ # We need to read the audio files as arrays and tokenize the targets.
1013
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
1014
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
1015
+ max_target_length = data_args.max_label_length
1016
+ min_target_length = data_args.min_label_length
1017
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
1018
+ audio_column_name = data_args.audio_column_name
1019
+ num_workers = data_args.preprocessing_num_workers
1020
+ text_column_name = data_args.text_column_name
1021
+ model_input_name = feature_extractor.model_input_names[0]
1022
+ do_lower_case = data_args.do_lower_case
1023
+ dataset_name = data_args.dataset_name
1024
+ chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
1025
+ chars_to_ignore_regex = f'[{re.escape("".join(chars_to_ignore))}]'  # escape so "-" is matched literally rather than forming a character range
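+ # Illustrative example: with this character set, 'Hello, how are you?' -> 'Hello how are you';
+ # apostrophes and any remaining quotes are additionally stripped in `remove_punctuation` below.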
1026
+ # gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
1027
+ # gigaspeech_disfluencies = ["<other>", "<sil>"]
1028
+ # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
1029
+ # "[vocalized-noise]", "_1"]
1030
+ # swb_punctuations = ["{", "}", "[", "]-", "]"]
1031
+ # earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>", "<unk>"]
1032
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
1033
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
1034
+
1035
+ if training_args.do_train and data_args.max_train_samples is not None:
1036
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].select(range(data_args.max_train_samples))
1037
+
1038
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1039
+ raw_datasets[data_args.eval_split_name] = raw_datasets[data_args.eval_split_name].select(range(data_args.max_eval_samples))
1040
+
1041
+ if training_args.do_predict and data_args.max_test_samples is not None:
1042
+ raw_datasets[data_args.test_split_name] = raw_datasets[data_args.test_split_name].select(range(data_args.max_test_samples))
1043
+
1044
+ if training_args.do_train and data_args.remove_punctuation:
1045
+
1046
+ def remove_punctuation(batch):
1047
+ batch[text_column_name] = (
1048
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
1049
+ )
+ return batch  # datasets' .map() only applies the update if the mapping function returns a dict
1050
+
1051
+ raw_datasets[data_args.train_split_name] = raw_datasets[data_args.train_split_name].map(
1052
+ remove_punctuation,
1053
+ num_proc=data_args.preprocessing_num_workers,
1054
+ desc="removing punctuation from train split",
1055
+ )
1056
+
1057
+ # filter data where the targets are ignored in scoring
1058
+ def is_target_labels(input_str):
1059
+ return input_str.lower() not in ignore_segments
1060
+
1061
+ raw_datasets = raw_datasets.filter(
1062
+ is_target_labels,
1063
+ num_proc=num_workers,
1064
+ input_columns=[text_column_name],
1065
+ desc="filtering data where the targets are ignored in scoring",
1066
+ )
1067
+
1068
+ def prepare_dataset(batch):
1069
+ # process audio
1070
+ try:
1071
+ sample = batch[audio_column_name]
1072
+ except ValueError:
1073
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
1074
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1075
+ # process audio length
1076
+ batch[model_input_name] = inputs.input_values[0]
1077
+ batch["input_length"] = len(batch["input_values"])
1078
+
1079
+ # process targets
1080
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
1081
+
1082
+ # if dataset_name == "google/xtreme_s":
1083
+ # # Finally, we tokenize the processed text
1084
+ # batch["labels"] = tokenizer(input_str).input_ids
1085
+ # batch["labels_length"] = len(batch["labels"])
1086
+ # return batch
1087
+
1088
+ # # Common Voice 9
1089
+ # if input_str.startswith('"') and input_str.endswith('"'):
1090
+ # # we can remove trailing quotation marks as they do not affect the transcription
1091
+ # input_str = input_str[1:-1]
1092
+ # # normalize quotation marks
1093
+ # input_str = re.sub(r'["“”]', '"', input_str)
1094
+ # # normalize apostrophes
1095
+ # input_str = re.sub(r"[’']", "'", input_str)
1096
+ # # normalize hyphens
1097
+ # input_str = re.sub(r"[—–]", "-", input_str)
1098
+ # # replace double quotation marks with single
1099
+ # input_str = input_str.replace('""', '"')
1100
+ # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str):
1101
+ # # for CV9, we'll normalize the text to always finish with punctuation
1102
+ # if input_str[-1] not in [".", "?", "!"]:
1103
+ # input_str = input_str + "."
1104
+
1105
+ # # TEDLIUM-3
1106
+ # # delete the <unk> token from the text and replace spaced apostrophes with un-spaced
1107
+ # input_str = input_str.replace("<unk>", "").replace(" '", "'")
1108
+
1109
+ # # GigaSpeech
1110
+ # for disfluency in gigaspeech_disfluencies:
1111
+ # input_str = input_str.replace(disfluency, "")
1112
+ # # convert spelled out punctuation to symbolic form
1113
+ # for punctuation, replacement in gigaspeech_punctuation.items():
1114
+ # input_str = input_str.replace(punctuation, replacement)
1115
+ # if dataset_name == "speechcolab/gigaspeech" and len(input_str):
1116
+ # # for GS, we'll normalize the text to always finish with punctuation
1117
+ # if input_str[-1] not in [".", "?", "!"]:
1118
+ # input_str = input_str + "."
1119
+
1120
+ # # SWB
1121
+ # for disfluency in swb_disfluencies:
1122
+ # input_str = input_str.replace(disfluency, "")
1123
+ # # remove parenthesised text (test data only)
1124
+ # input_str = re.sub("[\(].*?[\)]", "", input_str)
1125
+ # for punctuation in swb_punctuations:
1126
+ # input_str = input_str.replace(punctuation, "")
1127
+ # # replace anomalous words with their correct transcriptions
1128
+ # split_str = input_str.split("/")
1129
+ # if len(split_str) > 1:
1130
+ # input_str = " ".join(
1131
+ # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1132
+
1133
+ # # Earnings 22
1134
+ # for disfluency in earnings_disfluencies:
1135
+ # input_str = input_str.replace(disfluency, "")
1136
+ # # replace mal-formatted ellipsis
1137
+ # input_str = input_str.replace("…", ".")
1138
+
1139
+ # JIWER compliance
1140
+ # remove multiple spaces
1141
+ input_str = re.sub(r"\s\s+", " ", input_str)
1142
+ # strip trailing spaces
1143
+ input_str = input_str.strip()
1144
+
1145
+ # Finally, we tokenize the processed text
1146
+ batch["labels"] = tokenizer(input_str).input_ids
1147
+ batch["labels_length"] = len(batch["labels"])
1148
+ return batch
1149
+
1150
+ vectorized_datasets = raw_datasets.map(
1151
+ prepare_dataset,
1152
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1153
+ num_proc=num_workers,
1154
+ desc="preprocess dataset",
1155
+ )
1156
+
1157
+ # filter data with inputs shorter than min_input_length or longer than max_input_length
1158
+ def is_audio_in_length_range(length):
1159
+ return length > min_input_length and length < max_input_length
1160
+
1161
+ vectorized_datasets = vectorized_datasets.filter(
1162
+ is_audio_in_length_range,
1163
+ num_proc=num_workers,
1164
+ input_columns=["input_length"],
1165
+ )
1166
+
1167
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1168
+ def is_labels_in_length_range(length):
1169
+ return length > min_target_length # and length < max_target_length
1170
+
1171
+ vectorized_datasets = vectorized_datasets.filter(
1172
+ is_labels_in_length_range,
1173
+ num_proc=num_workers,
1174
+ input_columns=["labels_length"],
1175
+ )
1176
+
1177
+ # for large datasets it is advised to run the preprocessing on a
1178
+ # single machine first with `args.preprocessing_only` since there will most likely
1179
+ # be a timeout when running the script in distributed mode.
1180
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1181
+ # cached dataset
1182
+ if data_args.preprocessing_only:
1183
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1184
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1185
+ return
1186
+
1187
+ # 8. Load Metrics
1188
+ wer_metric = load_metric("wer")
1189
+ cer_metric = load_metric("cer")
1190
+
1191
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1192
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1193
+
1194
+ pred_str = tokenizer.batch_decode(pred_ids)
1195
+ # we do not want to group tokens when computing the metrics
1196
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
1197
+
1198
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
1199
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
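+ # WER is the word-level edit distance (substitutions + deletions + insertions) divided by the
+ # number of reference words; CER is the same ratio computed at the character level.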
1200
+
1201
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1202
+
1203
+ # 9. save feature extractor, tokenizer and config
1204
+ feature_extractor.save_pretrained(training_args.output_dir)
1205
+ tokenizer.save_pretrained(training_args.output_dir)
1206
+ config.save_pretrained(training_args.output_dir)
1207
+
1208
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1209
+
1210
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1211
+ processor=processor,
1212
+ input_padding="longest",
1213
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1214
+ max_label_length=data_args.max_label_length,
1215
+ )
1216
+
1217
+ # Enable tensorboard only on the master node
1218
+ has_tensorboard = is_tensorboard_available()
1219
+ if has_tensorboard and jax.process_index() == 0:
1220
+ try:
1221
+ from flax.metrics.tensorboard import SummaryWriter
1222
+
1223
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1224
+ except ImportError as ie:
1225
+ has_tensorboard = False
1226
+ logger.warning(
1227
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1228
+ )
1229
+ else:
1230
+ logger.warning(
1231
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1232
+ "Please run `pip install tensorboard` to enable."
1233
+ )
1234
+
1235
+ # 10. Handle the repository creation
1236
+ if training_args.push_to_hub:
1237
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1238
+ git_lfs_extensions = f.read()
1239
+ if "*.wandb" not in git_lfs_extensions:
1240
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1241
+ if training_args.hub_model_id is None:
1242
+ repo_name = get_full_repo_name(
1243
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1244
+ )
1245
+ else:
1246
+ repo_name = training_args.hub_model_id
1247
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1248
+
1249
+ # 11. Initialize our training
1250
+ rng = jax.random.PRNGKey(training_args.seed)
1251
+ rng, dropout_rng = jax.random.split(rng)
1252
+
1253
+ # Store some constants
1254
+ max_steps = int(training_args.max_steps)
1255
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1256
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1257
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1258
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1259
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1260
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1261
+
1262
+ if training_args.do_train:
1263
+ num_train_samples = len(vectorized_datasets[data_args.train_split_name])
1264
+ steps_per_epoch = num_train_samples // batch_size_per_update
1265
+ if max_steps > 0:
1266
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1267
+ total_train_steps = max_steps
1268
+ else:
1269
+ num_epochs = int(training_args.num_train_epochs)
1270
+ total_train_steps = steps_per_epoch * num_epochs
1271
+
1272
+ # Create learning rate schedule
1274
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1275
+ total_train_steps,
1276
+ training_args.warmup_steps,
1277
+ training_args.learning_rate,
1278
+ )
1279
+
1280
+ # We use Optax's "masking" functionality to not apply weight decay
1281
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1282
+ # mask boolean with the same structure as the parameters.
1283
+ # The mask is True for parameters that should be decayed.
1284
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1285
+ # For FlaxT5, one should correct the layer norm parameter naming
1286
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1287
+ def decay_mask_fn(params):
1288
+ flat_params = traverse_util.flatten_dict(params)
1289
+ layer_norm_params = [
1290
+ (name, "scale")
1291
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1292
+ ]
1293
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1294
+ return traverse_util.unflatten_dict(flat_mask)
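+ # e.g. a parameter path ending in ("...", "bias") or ("layer_norm", "scale") maps to False (no decay),
+ # while a path ending in ("...", "kernel") maps to True (weight decay applied).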
1295
+
1296
+ if training_args.adafactor:
1297
+ # Create Adafactor optimizer
1298
+ optim = optax.adafactor(
1299
+ learning_rate=linear_decay_lr_schedule_fn,
1300
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1301
+ weight_decay_rate=training_args.weight_decay,
1302
+ weight_decay_mask=decay_mask_fn,
1303
+ )
1304
+ else:
1305
+ # Create AdamW optimizer
1306
+ optim = optax.adamw(
1307
+ learning_rate=linear_decay_lr_schedule_fn,
1308
+ b1=training_args.adam_beta1,
1309
+ b2=training_args.adam_beta2,
1310
+ eps=training_args.adam_epsilon,
1311
+ weight_decay=training_args.weight_decay,
1312
+ mask=decay_mask_fn,
1313
+ )
1314
+
1315
+ # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
1316
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1317
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
1318
+ else:
1319
+ num_epochs = 0
1320
+ total_train_steps = 0
1321
+ num_train_samples = 0
1322
+ optim = None
1323
+
1324
+ # Setup train state
1325
+ state = MixedPrecisionTrainState.create(
1326
+ apply_fn=model.__call__,
1327
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1328
+ params=model.params,
1329
+ tx=optim,
1330
+ to_dtype=to_dtype,
1331
+ dropout_rng=dropout_rng,
1332
+ max_grad_norm=training_args.max_grad_norm,
1333
+ )
1334
+
1335
+ # Replicate the train state on each device
1336
+ state = state.replicate()
1337
+ blank_id = model.config.pad_token_id
1338
+
1339
+ # Define gradient update step fn
1340
+ def train_step(state, batch):
1341
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1342
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1343
+
1344
+ def compute_loss(params, minibatch):
1345
+ labels = minibatch.pop("labels")
1346
+ logits = state.apply_fn(
1347
+ **minibatch,
1348
+ params=params,
1349
+ dropout_rng=dropout_rng,
1350
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1351
+ train=True,
1352
+ )[0]
1353
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], batch["attention_mask"])
1354
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
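+ # logits_mask downsamples the raw audio attention_mask to the encoder's frame resolution so that
+ # padded frames are excluded from the loss; `blank_id` (the tokenizer's pad token) is used as the CTC blank.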
1355
+
1356
+ return loss
1357
+
1358
+ grad_fn = jax.value_and_grad(compute_loss)
1359
+
1360
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1361
+ loss, grad = grad_fn(to_dtype(state.params), batch)
1362
+
1363
+ # Custom gradient accumulation
1364
+ else:
1365
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1366
+ batch = jax.tree_map(
1367
+ lambda x: x.reshape(
1368
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1369
+ ),
1370
+ batch,
1371
+ )
1372
+
1373
+ def accum_minibatch_step(accum_grad, minibatch):
1374
+ # compute loss, num labels and grad over minibatch and accumulate
1375
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
1376
+ return jax.tree_map(jnp.add, accum_grad, grad), loss
1377
+
1378
+ # create an initial state for accumulating losses, num labels and gradients
1379
+ init_grad = jax.tree_map(jnp.zeros_like, to_dtype(state.params))
1380
+ # loop accum minibatch step over the number of gradient accumulation steps
1381
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
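+ # `jax.lax.scan` carries the running gradient sum across minibatches and stacks the per-step
+ # outputs, so `grad` is the summed gradient and `loss` holds one loss value per accumulation step.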
1382
+
1383
+ # update state
1384
+ new_state = state.apply_gradients(
1385
+ grads=grad,
1386
+ dropout_rng=new_dropout_rng,
1387
+ to_dtype=to_dtype,
1388
+ )
1389
+
1390
+ # compute gradient norms over all layers and globally for detailed monitoring
1391
+ layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
1392
+ logs = {
1393
+ "layer_grad_norm": layer_grad_norm,
1394
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
1395
+ }
1396
+
1397
+ # compute parameter norms over all layers and globally for detailed monitoring
1398
+ layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
1399
+ logs["layer_param_norm"] = layer_param_norm
1400
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
1401
+
1402
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1403
+ metrics.update(logs)
1404
+
1405
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1406
+ # metrics = to_fp32(metrics)
1407
+
1408
+ return new_state, metrics
1409
+
1410
+ # Define eval fn
1411
+ def eval_step(params, batch):
1412
+ labels = batch.pop("labels")
1413
+ logits = model(**batch, params=params, train=False)[0]
1414
+
1415
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
1416
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1417
+
1418
+ pred_ids = jnp.argmax(logits, axis=-1)
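+ # greedy (argmax) CTC decoding: repeated tokens and blank/pad tokens are collapsed later by
+ # `tokenizer.batch_decode` when predictions are converted back to text in `compute_metrics`.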
1419
+
1420
+ # summarize metrics
1421
+ metrics = {"loss": loss}
1422
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1423
+ # metrics = to_fp32(metrics)
1424
+ return metrics, pred_ids
1425
+
1426
+ # Create parallel version of the train and eval step
1427
+ if training_args.do_train:
1428
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1429
+
1430
+ if training_args.do_eval:
1431
+ p_eval_step = jax.pmap(eval_step, "batch")
1432
+
1433
+ def run_evaluation(step):
1434
+ if training_args.do_eval:
1435
+ # ======================== Evaluating ==============================
1436
+ eval_metrics = []
1437
+ eval_preds = []
1438
+ eval_labels = []
1439
+
1440
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1441
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[data_args.eval_split_name], eval_batch_size)
1442
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1443
+
1444
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1445
+ samples = [vectorized_datasets[data_args.eval_split_name][int(idx)] for idx in batch_idx]
1446
+ batch = data_collator(samples)
1447
+ labels = batch["labels"]
1448
+
1449
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1450
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1451
+ eval_metrics.append(metrics)
1452
+
1453
+ eval_labels.extend(labels)
1454
+
1455
+ # normalize eval metrics
1456
+ eval_metrics = get_metrics(eval_metrics)
1457
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1458
+ eval_metrics = to_fp32(eval_metrics)
1459
+
1460
+ # always run compute metrics
1461
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1462
+ eval_metrics.update(error_rate_metric)
1463
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1464
+
1465
+ # Print metrics and update progress bar
1466
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1467
+ epochs.write(desc)
1468
+ epochs.desc = desc
1469
+
1470
+ # Save metrics
1471
+ write_wandb_log(eval_metrics, step, prefix="eval")
1472
+ write_wandb_pred(pred_str, label_str, step)
1473
+ # if has_tensorboard and jax.process_index() == 0:
1474
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1475
+
1476
+ def save_checkpoint(step):
1477
+ # save and push checkpoint to the hub
1478
+ if jax.process_index() == 0:
1479
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
1480
+ model.save_pretrained(training_args.output_dir, params=params)
1481
+ tokenizer.save_pretrained(training_args.output_dir)
1482
+ if training_args.push_to_hub:
1483
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1484
+
1485
+ logger.info("***** Running training *****")
1486
+ logger.info(f" Num examples = {num_train_samples}")
1487
+ logger.info(f" Num Epochs = {num_epochs}")
1488
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1489
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1490
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1491
+ logger.info(f" Total optimization steps = {total_train_steps}")
1492
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
1493
+ logger.info(f" Use scan: {config.use_scan}")
1494
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
1495
+
1496
+ train_time = cur_step = 0
1497
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1498
+ for epoch in epochs:
1499
+ if training_args.do_train:
1500
+ # ======================== Training ================================
1501
+ train_start = time.time()
1502
+
1503
+ # Create sampling rng
1504
+ rng, input_rng = jax.random.split(rng)
1505
+
1506
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1507
+ train_samples_idx = get_grouped_indices(vectorized_datasets[data_args.train_split_name], batch_size_per_update, input_rng)
1508
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1509
+
1510
+ # Gather the indices for creating the batch and do a training step
1511
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1512
+ samples = [vectorized_datasets[data_args.train_split_name][int(idx)] for idx in batch_idx]
1513
+ batch = data_collator(samples)
1514
+ batch = shard(batch.data)
1515
+ try:
1516
+ state, train_metric = p_train_step(state, batch)
1517
+ except TypeError as e:
1518
+ logger.warning("Encountered following error: \n", e)
1519
+
1520
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1521
+
1522
+ if cur_step % training_args.logging_steps == 0:
1523
+ # Save metrics
1524
+ train_metric = unreplicate(train_metric)
1525
+ train_time += time.time() - train_start
1526
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1527
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix=data_args.train_split_name)
1528
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1529
+ # if has_tensorboard and jax.process_index() == 0:
1530
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1531
+
1532
+ epochs.write(
1533
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1534
+ )
1535
+
1536
+ if cur_step % total_train_steps == 0:
1537
+ break
1538
+
1539
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1540
+ run_evaluation(cur_step)
1541
+
1542
+ if cur_step % training_args.save_steps == 0:
1543
+ save_checkpoint(cur_step)
1544
+
1545
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1546
+ # run evaluation at the end of the epoch if eval steps are not specified
1547
+ run_evaluation(cur_step)
1548
+ save_checkpoint(cur_step)
1549
+
1550
+ if training_args.do_train:
1551
+ save_checkpoint(cur_step)
1552
+
1553
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1554
+
1555
+ if training_args.do_eval:
1556
+ run_evaluation(cur_step)
1557
+
1558
+ # TODO: collapse 'do_predict' into the run_evaluation function
1559
+ if training_args.do_predict:
1560
+ for split in [data_args.test_split_name]:
1561
+ # ======================== Evaluating ==============================
1562
+ eval_metrics = []
1563
+ eval_preds = []
1564
+ eval_labels = []
1565
+
1566
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
1567
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1568
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1569
+
1570
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
1571
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1572
+ batch = data_collator(samples)
1573
+ labels = batch["labels"]
1574
+
1575
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1576
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1577
+ eval_metrics.append(metrics)
1578
+
1579
+ eval_labels.extend(labels)
1580
+
1581
+ # normalize eval metrics
1582
+ eval_metrics = get_metrics(eval_metrics)
1583
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1584
+ eval_metrics = to_fp32(eval_metrics)
1585
+
1586
+ # always run compute metrics
1587
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1588
+ eval_metrics.update(error_rate_metric)
1589
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1590
+
1591
+ # Print metrics and update progress bar
1592
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1593
+ epochs.write(desc)
1594
+ epochs.desc = desc
1595
+
1596
+ # Save metrics
1597
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
1598
+ write_wandb_pred(pred_str, label_str, cur_step, prefix=split)
1599
+ # if has_tensorboard and jax.process_index() == 0:
1600
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
1601
+
1602
+
1603
+ if __name__ == "__main__":
1604
+ main()
special_tokens_map.json ADDED
@@ -0,0 +1,190 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ {
11
+ "content": "</s>",
12
+ "lstrip": false,
13
+ "normalized": true,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ },
17
+ {
18
+ "content": "<s>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ {
25
+ "content": "</s>",
26
+ "lstrip": false,
27
+ "normalized": true,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ },
31
+ {
32
+ "content": "<s>",
33
+ "lstrip": false,
34
+ "normalized": true,
35
+ "rstrip": false,
36
+ "single_word": false
37
+ },
38
+ {
39
+ "content": "</s>",
40
+ "lstrip": false,
41
+ "normalized": true,
42
+ "rstrip": false,
43
+ "single_word": false
44
+ },
45
+ {
46
+ "content": "<s>",
47
+ "lstrip": false,
48
+ "normalized": true,
49
+ "rstrip": false,
50
+ "single_word": false
51
+ },
52
+ {
53
+ "content": "</s>",
54
+ "lstrip": false,
55
+ "normalized": true,
56
+ "rstrip": false,
57
+ "single_word": false
58
+ },
59
+ {
60
+ "content": "<s>",
61
+ "lstrip": false,
62
+ "normalized": true,
63
+ "rstrip": false,
64
+ "single_word": false
65
+ },
66
+ {
67
+ "content": "</s>",
68
+ "lstrip": false,
69
+ "normalized": true,
70
+ "rstrip": false,
71
+ "single_word": false
72
+ },
73
+ {
74
+ "content": "<s>",
75
+ "lstrip": false,
76
+ "normalized": true,
77
+ "rstrip": false,
78
+ "single_word": false
79
+ },
80
+ {
81
+ "content": "</s>",
82
+ "lstrip": false,
83
+ "normalized": true,
84
+ "rstrip": false,
85
+ "single_word": false
86
+ },
87
+ {
88
+ "content": "<s>",
89
+ "lstrip": false,
90
+ "normalized": true,
91
+ "rstrip": false,
92
+ "single_word": false
93
+ },
94
+ {
95
+ "content": "</s>",
96
+ "lstrip": false,
97
+ "normalized": true,
98
+ "rstrip": false,
99
+ "single_word": false
100
+ },
101
+ {
102
+ "content": "<s>",
103
+ "lstrip": false,
104
+ "normalized": true,
105
+ "rstrip": false,
106
+ "single_word": false
107
+ },
108
+ {
109
+ "content": "</s>",
110
+ "lstrip": false,
111
+ "normalized": true,
112
+ "rstrip": false,
113
+ "single_word": false
114
+ },
115
+ {
116
+ "content": "<s>",
117
+ "lstrip": false,
118
+ "normalized": true,
119
+ "rstrip": false,
120
+ "single_word": false
121
+ },
122
+ {
123
+ "content": "</s>",
124
+ "lstrip": false,
125
+ "normalized": true,
126
+ "rstrip": false,
127
+ "single_word": false
128
+ },
129
+ {
130
+ "content": "<s>",
131
+ "lstrip": false,
132
+ "normalized": true,
133
+ "rstrip": false,
134
+ "single_word": false
135
+ },
136
+ {
137
+ "content": "</s>",
138
+ "lstrip": false,
139
+ "normalized": true,
140
+ "rstrip": false,
141
+ "single_word": false
142
+ },
143
+ {
144
+ "content": "<s>",
145
+ "lstrip": false,
146
+ "normalized": true,
147
+ "rstrip": false,
148
+ "single_word": false
149
+ },
150
+ {
151
+ "content": "</s>",
152
+ "lstrip": false,
153
+ "normalized": true,
154
+ "rstrip": false,
155
+ "single_word": false
156
+ },
157
+ {
158
+ "content": "<s>",
159
+ "lstrip": false,
160
+ "normalized": true,
161
+ "rstrip": false,
162
+ "single_word": false
163
+ },
164
+ {
165
+ "content": "</s>",
166
+ "lstrip": false,
167
+ "normalized": true,
168
+ "rstrip": false,
169
+ "single_word": false
170
+ },
171
+ {
172
+ "content": "<s>",
173
+ "lstrip": false,
174
+ "normalized": true,
175
+ "rstrip": false,
176
+ "single_word": false
177
+ },
178
+ {
179
+ "content": "</s>",
180
+ "lstrip": false,
181
+ "normalized": true,
182
+ "rstrip": false,
183
+ "single_word": false
184
+ }
185
+ ],
186
+ "bos_token": "<s>",
187
+ "eos_token": "</s>",
188
+ "pad_token": "[PAD]",
189
+ "unk_token": "[UNK]"
190
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "bos_token": "<s>",
3
+ "do_lower_case": false,
4
+ "eos_token": "</s>",
5
+ "name_or_path": "./",
6
+ "pad_token": "[PAD]",
7
+ "replace_word_delimiter_char": " ",
8
+ "special_tokens_map_file": null,
9
+ "tokenizer_class": "Wav2Vec2CTCTokenizer",
10
+ "unk_token": "[UNK]",
11
+ "word_delimiter_token": "|"
12
+ }
vocab.json ADDED
@@ -0,0 +1,41 @@
1
+ {
2
+ "(": 1,
3
+ ")": 2,
4
+ "0": 3,
5
+ "3": 4,
6
+ "7": 5,
7
+ "8": 6,
8
+ "9": 7,
9
+ "[PAD]": 38,
10
+ "[UNK]": 37,
11
+ "a": 8,
12
+ "b": 9,
13
+ "c": 10,
14
+ "d": 11,
15
+ "e": 12,
16
+ "f": 13,
17
+ "g": 14,
18
+ "h": 15,
19
+ "i": 16,
20
+ "j": 17,
21
+ "k": 18,
22
+ "l": 19,
23
+ "m": 20,
24
+ "n": 21,
25
+ "o": 22,
26
+ "p": 23,
27
+ "q": 24,
28
+ "r": 25,
29
+ "s": 26,
30
+ "t": 27,
31
+ "u": 28,
32
+ "v": 29,
33
+ "w": 30,
34
+ "x": 31,
35
+ "y": 32,
36
+ "z": 33,
37
+ "|": 0,
38
+ "å": 34,
39
+ "æ": 35,
40
+ "ø": 36
41
+ }
wandb/debug-internal.log ADDED
@@ -0,0 +1 @@
1
+ run-20220730_174606-j2u4n7h4/logs/debug-internal.log
wandb/debug.log ADDED
@@ -0,0 +1 @@
1
+ run-20220730_174606-j2u4n7h4/logs/debug.log
wandb/latest-run ADDED
@@ -0,0 +1 @@
1
+ run-20220730_174606-j2u4n7h4
wandb/run-20220729_183213-356uc50u/files/code/run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1596 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import unreplicate, pad_shard_unpad
44
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.utils import check_min_version
59
+ from transformers.utils.versions import require_version
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ check_min_version("4.17.0.dev0")
64
+
65
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ @flax.struct.dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ feature_extractor_name: Optional[str] = field(
86
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
91
+ )
92
+ use_fast_tokenizer: bool = field(
93
+ default=True,
94
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
95
+ )
96
+ model_revision: str = field(
97
+ default="main",
98
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
99
+ )
100
+ use_auth_token: bool = field(
101
+ default=False,
102
+ metadata={
103
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
104
+ "with private models)."
105
+ },
106
+ )
107
+ freeze_feature_encoder: bool = field(
108
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
109
+ )
110
+ attention_dropout: float = field(
111
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
112
+ )
113
+ activation_dropout: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
117
+ },
118
+ )
119
+ hidden_dropout: float = field(
120
+ default=0.1,
121
+ metadata={
122
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
123
+ },
124
+ )
125
+ feat_proj_dropout: float = field(
126
+ default=0.0,
127
+ metadata={
128
+ "help": "The feat proj dropout probability for feature encoder representations."
129
+ },
130
+ )
131
+ final_dropout: float = field(
132
+ default=0.0,
133
+ metadata={"help": "The dropout probability for the final projection layer."},
134
+ )
135
+ mask_time_prob: float = field(
136
+ default=0.1,
137
+ metadata={
138
+ "help": "The spec aug dropout probability for feature encoder representations."
139
+ },
140
+ )
141
+ mask_time_length: int = field(
142
+ default=10,
143
+ metadata={"help": "Length of vector span to mask along the time axis."},
144
+ )
145
+ mask_feature_prob: float = field(
146
+ default=0.0,
147
+ metadata={
148
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
149
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
150
+ },
151
+ )
152
+ mask_feature_length: int = field(
153
+ default=10,
154
+ metadata={"help": "Length of vector span to mask along the feature axis."},
155
+ )
156
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
157
+ ctc_loss_reduction: Optional[str] = field(
158
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
159
+ )
160
+ ctc_zero_infinity: Optional[bool] = field(
161
+ default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."}
162
+ )
163
+
164
+
165
+ @flax.struct.dataclass
166
+ class DataTrainingArguments:
167
+ """
168
+ Arguments pertaining to what data we are going to input our model for training and eval.
169
+ """
170
+
171
+ dataset_name: str = field(
172
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
173
+ )
174
+ dataset_config_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
176
+ )
177
+ text_column: Optional[str] = field(
178
+ default=None,
179
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
180
+ )
181
+ dataset_cache_dir: Optional[str] = field(
182
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
183
+ )
184
+ overwrite_cache: bool = field(
185
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
186
+ )
187
+ preprocessing_num_workers: Optional[int] = field(
188
+ default=None,
189
+ metadata={"help": "The number of processes to use for the preprocessing."},
190
+ )
191
+ max_train_samples: Optional[int] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ },
197
+ )
198
+ max_eval_samples: Optional[int] = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
202
+ "value if set."
203
+ },
204
+ )
205
+ max_test_samples: Optional[int] = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
209
+ "value if set."
210
+ },
211
+ )
212
+ audio_column_name: str = field(
213
+ default="audio",
214
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
215
+ )
216
+ text_column_name: str = field(
217
+ default="text",
218
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
219
+ )
220
+ max_duration_in_seconds: float = field(
221
+ default=20.0,
222
+ metadata={
223
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
224
+ },
225
+ )
226
+ min_duration_in_seconds: float = field(
227
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
228
+ )
229
+ max_label_length: Optional[int] = field(
230
+ default=512,
231
+ metadata={
232
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
233
+ "than this will be filtered."
234
+ },
235
+ )
236
+ min_label_length: Optional[int] = field(
237
+ default=2,
238
+ metadata={
239
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
240
+ "than this will be filtered."
241
+ },
242
+ )
243
+ pad_input_to_multiple_of: Optional[int] = field(
244
+ default=32000,
245
+ metadata={
246
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
247
+ "This is important to avoid triggering recompilations on TPU."
248
+ },
249
+ )
250
+ pad_target_to_multiple_of: Optional[int] = field(
251
+ default=None,
252
+ metadata={
253
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
254
+ "This is important to avoid triggering recompilations on TPU."
255
+ },
256
+ )
257
+ preprocessing_only: bool = field(
258
+ default=False,
259
+ metadata={
260
+ "help": "Whether to only do data preprocessing and skip training. "
261
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
262
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
263
+ "so that the cached datasets can consequently be loaded in distributed training"
264
+ },
265
+ )
266
+ train_split_name: str = field(
267
+ default="train",
268
+ metadata={
269
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
270
+ },
271
+ )
272
+ eval_split_name: str = field(
273
+ default="validation",
274
+ metadata={
275
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
276
+ },
277
+ )
278
+ do_lower_case: bool = field(
279
+ default=True,
280
+ metadata={"help": "Whether the target text should be lower cased."},
281
+ )
282
+ wandb_project: str = field(
283
+ default="flax-speech-recognition-ctc",
284
+ metadata={"help": "The name of the wandb project."},
285
+ )
286
+ wandb_name: str = field(
287
+ default=None,
288
+ metadata={"help": "The name of the wandb run."},
289
+ )
290
+ wandb_job_type: str = field(
291
+ default="CTC",
292
+ metadata={"help": "The name of the wandb job type."},
293
+ )
294
+ test_split_name: str = field(
295
+ default="test",
296
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
297
+ )
298
+ remove_punctuation: bool = field(
299
+ default=False, metadata={"help": "Whether or not to remove punctuation during training."}
300
+ )
301
+
302
+
303
+ # @flax.struct.dataclass
304
+ @dataclass
305
+ class FlaxTrainingArguments(TrainingArguments):
306
+ precision: str = field(
307
+ default="full",
308
+ metadata={
309
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision"
310
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
311
+ },
312
+ )
313
+ matmul_precision: str = field(
314
+ default="default",
315
+ metadata={
316
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
317
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
318
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
319
+ "it only changes the behaviors of calls with no such argument provided. "
320
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
321
+ },
322
+ )
323
+ multisteps: bool = field(
324
+ default=False,
325
+ metadata={
326
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
327
+ "a custom gradient accumulation implementation will be employed."
328
+ },
329
+ )
330
+
331
+
332
+ def to_fp32(t):
333
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
334
+
335
+
336
+ def to_bf16(t):
337
+ return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
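+ # Both casts are selective: `to_bf16` only downcasts float32 leaves and `to_fp32` only upcasts
+ # bfloat16 leaves, so integer arrays (e.g. token ids) pass through unchanged.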
338
+
339
+
340
+ class MixedPrecisionTrainState(struct.PyTreeNode):
341
+ """Train state for use with a single Optax optimizer.
342
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
343
+
344
+ Synopsis::
345
+
346
+ state = TrainState.create(
347
+ apply_fn=model.apply,
348
+ params=variables['params'],
349
+ tx=tx)
350
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
351
+ for batch in data:
352
+ grads = grad_fn(state.params, batch)
353
+ state = state.apply_gradients(grads=grads)
354
+
355
+ Args:
356
+ step: Counter starts at 0 and is incremented by every call to
357
+ `.apply_gradients()`.
358
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
359
+ convenience to have a shorter params list for the `train_step()` function
360
+ in your training loop.
361
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
362
+ tx: An Optax gradient transformation.
363
+ opt_state: The state for `tx`.
364
+ dropout_rng: PRNG key for stochastic operations.
365
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
366
+ """
367
+
368
+ step: int
369
+ apply_fn: Callable = struct.field(pytree_node=False)
370
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
371
+ params: core.FrozenDict[str, Any]
372
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
373
+ opt_state: optax.OptState
374
+ dropout_rng: jnp.ndarray
375
+ max_grad_norm: Optional[float] = 1.0
376
+
377
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
378
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
379
+
380
+ Note that internally this function calls `.tx.update()` followed by a call
381
+ to `optax.apply_updates()` to update `params` and `opt_state`.
382
+
383
+ Args:
384
+ grads: Gradients that have the same pytree structure as `.params`.
385
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
386
+
387
+ Returns:
388
+ An updated instance of `self` with `step` incremented by one, `params`
389
+ and `opt_state` updated by applying `grads`, and additional attributes
390
+ replaced as specified by `kwargs`.
391
+ """
392
+
393
+ # clip gradients by global l2 norm
394
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
395
+ g_norm = linear_algebra.global_norm(grads)
396
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
397
+ grads = jax.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
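+ # clip-by-global-norm: if the global norm exceeds `max_grad_norm`, gradients are rescaled by
+ # `max_grad_norm / g_norm`; otherwise `g_norm` is clamped to `max_grad_norm` and the factor is 1.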
398
+
399
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
400
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
401
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
402
+
403
+ new_params = optax.apply_updates(self.params, updates)
404
+ return self.replace(
405
+ step=self.step + 1,
406
+ params=new_params,
407
+ opt_state=to_dtype(new_opt_state),
408
+ **kwargs,
409
+ )
410
+
411
+ @classmethod
412
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
413
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
414
+ # downcast optimizer state to bf16 if mixed-precision training
415
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
416
+ return cls(
417
+ step=0,
418
+ apply_fn=apply_fn,
419
+ params=params,
420
+ tx=tx,
421
+ opt_state=opt_state,
422
+ **kwargs,
423
+ )
424
+
425
+ def replicate(self):
426
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
427
+
428
+
429
+ @flax.struct.dataclass
430
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
431
+ """
432
+ Data collator that will dynamically pad the inputs received.
433
+ Args:
434
+ processor ([`Wav2Vec2Processor`])
435
+ The processor used for proccessing the data.
436
+ decoder_start_token_id (:obj: `int`)
437
+ The begin-of-sentence of the decoder.
438
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
439
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
440
+ among:
441
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
442
+ sequence if provided).
443
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
444
+ maximum acceptable input length for the model if that argument is not provided.
445
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
446
+ different lengths).
447
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
448
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
449
+ See above for details.
450
+ max_input_length (:obj:`float`, `optional`):
451
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
452
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
453
+ If set will pad the input sequence to a multiple of the provided value.
454
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
455
+ 7.5 (Volta).
456
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
457
+ If set will pad the target sequence to a multiple of the provided value.
458
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
459
+ 7.5 (Volta).
460
+ """
461
+
462
+ processor: Any
463
+ input_padding: Union[bool, str] = "longest"
464
+ label_padding: Union[bool, str] = "max_length"
465
+ pad_input_to_multiple_of: Optional[int] = None
466
+ pad_to_multiple_of_label: Optional[int] = None
467
+ max_input_length: Optional[float] = None
468
+ max_label_length: Optional[float] = None
469
+
470
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
471
+ # split inputs and labels since they have to be of different lengths and need
472
+ # different padding methods
473
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
474
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
475
+
476
+ # reformat list to dict and set to pytorch format
477
+ batch = self.processor.feature_extractor.pad(
478
+ input_features,
479
+ max_length=self.max_input_length,
480
+ padding=self.input_padding,
481
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
482
+ return_tensors="np",
483
+ )
484
+
485
+ labels_batch = self.processor.tokenizer.pad(
486
+ label_features,
487
+ max_length=self.max_label_length,
488
+ padding=self.label_padding,
489
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
490
+ return_tensors="np",
491
+ )
492
+
493
+ labels = labels_batch["input_ids"]
494
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
495
+ labels = labels.filled(fill_value=-100)
496
+
497
+ batch["labels"] = labels
498
+
499
+ return batch
500
+
501
+
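A minimal, self-contained sketch of the label-masking step in `__call__` above, with made-up token ids: positions where the tokenizer's attention mask is 0 are filled with -100 so the CTC loss ignores them.

import numpy as np

label_ids = np.array([[5, 9, 2, 0, 0], [7, 3, 4, 8, 1]])            # toy padded label ids
attention_mask = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])       # 0 marks padding
labels = np.ma.array(label_ids, mask=np.not_equal(attention_mask, 1))
print(labels.filled(fill_value=-100).tolist())
# [[5, 9, 2, -100, -100], [7, 3, 4, 8, 1]]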
502
+ def get_grouped_indices(
503
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
504
+ ) -> np.array:
505
+ """
506
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
507
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices corresponds to elements of similar
508
+ lengths. To do this, the indices are:
509
+
510
+ - randomly permuted (if a JAX rng is specified)
511
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
512
+ - sorted by length in each mega-batch
513
+
514
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
515
+ maximum length placed first, so that an OOM happens sooner rather than later.
516
+ """
517
+ lengths = dataset["input_length"]
518
+
519
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
520
+ if mega_batch_mult is None:
521
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
522
+ # Just in case, for tiny datasets
523
+ if mega_batch_mult == 0:
524
+ mega_batch_mult = 1
525
+
526
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
527
+ num_samples = len(lengths)
528
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
529
+
530
+ megabatch_size = mega_batch_mult * batch_size
531
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
532
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
533
+
534
+ # The rest is to get the biggest batch first.
535
+ # Since each megabatch is sorted by descending length, the longest element is the first
536
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
537
+ max_idx = np.argmax(megabatch_maximums).item()
538
+ # Switch to put the longest batch in first position
539
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
540
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
541
+
542
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
543
+
544
+ return megabatches
545
+
546
+
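The length-grouping above is easiest to see on a toy example (made-up lengths, and no JAX rng so the initial permutation is skipped): indices are chunked into mega-batches and sorted by descending length within each chunk.

import numpy as np

lengths = [3, 10, 7, 1, 8, 2, 5, 9]          # toy "input_length" values
batch_size, mega_batch_mult = 2, 2           # -> mega-batches of 4 indices
indices = np.arange(len(lengths))
megabatch_size = batch_size * mega_batch_mult
megabatches = [
    sorted(indices[i : i + megabatch_size].tolist(), key=lambda j: lengths[j], reverse=True)
    for i in range(0, len(lengths), megabatch_size)
]
print(megabatches)
# [[1, 2, 0, 3], [7, 4, 6, 5]] -> per-chunk lengths [10, 7, 3, 1] and [9, 8, 5, 2]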
547
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
548
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
549
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
550
+ num_samples = len(samples_idx)
551
+ if drop_last:
552
+ samples_to_remove = num_samples % batch_size
553
+ if samples_to_remove != 0:
554
+ samples_idx = samples_idx[:-samples_to_remove]
555
+ sections_split = num_samples // batch_size
556
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
557
+ else:
558
+ sections_split = math.ceil(num_samples / batch_size)
559
+ samples_idx = np.array_split(samples_idx, sections_split)
560
+ return samples_idx
561
+
562
+
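A quick illustration of the two `drop_last` behaviours with toy numbers (10 indices, batch size 4); note that `np.array_split` balances the remainder across batches rather than leaving a single short batch at the end.

import math
import numpy as np

samples_idx = np.arange(10)
batch_size = 4

# drop_last=True: keep only full batches
full = samples_idx[: len(samples_idx) - len(samples_idx) % batch_size].reshape(-1, batch_size)
print(full.tolist())                                   # [[0, 1, 2, 3], [4, 5, 6, 7]]

# drop_last=False: the remainder is distributed over ceil(10 / 4) = 3 sections
splits = np.array_split(samples_idx, math.ceil(len(samples_idx) / batch_size))
print([s.tolist() for s in splits])                    # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]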
563
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
564
+ summary_writer.scalar("train_time", train_time, step)
565
+
566
+ train_metrics = get_metrics(train_metrics)
567
+ for key, vals in train_metrics.items():
568
+ tag = f"train_{key}"
569
+ for i, val in enumerate(vals):
570
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
571
+
572
+
573
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
574
+ for metric_name, value in eval_metrics.items():
575
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
576
+
577
+ if pred_str is not None:
578
+ # write output actual predictions for debugging
579
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
580
+
581
+
582
+ def write_wandb_log(metrics, step, prefix=None):
583
+ if jax.process_index() == 0:
584
+ log_metrics = {}
585
+ for k, v in metrics.items():
586
+ if "layer" in k:
587
+ log_metrics[f"{k}/"] = v
588
+ elif prefix is not None:
589
+ log_metrics[f"{prefix}/{k}"] = v
590
+ else:
591
+ log_metrics[k] = v
592
+ wandb.log(log_metrics, step)
593
+
594
+
595
+ def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"):
596
+ if jax.process_index() == 0:
597
+ # convert str data to a wandb compatible format
598
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
599
+ # we'll log the first 50 predictions for each epoch
600
+ wandb.log(
601
+ {
602
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
603
+ columns=["label_str", "pred_str"], data=str_data[:num_log]
604
+ )
605
+ },
606
+ step,
607
+ )
608
+
609
+
610
+ def create_learning_rate_fn(
611
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
612
+ ) -> Callable[[int], jnp.array]:
613
+ """Returns a linear warmup, linear_decay learning rate function."""
614
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
615
+ decay_fn = optax.linear_schedule(
616
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
617
+ )
618
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
619
+ return schedule_fn
620
+
621
+
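A small standalone check of the warmup + linear-decay schedule built by `create_learning_rate_fn`, with made-up numbers (1000 total steps, 100 warmup steps, peak learning rate 1e-4):

import optax

num_train_steps, num_warmup_steps, learning_rate = 1000, 100, 1e-4
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
    init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])

for step in (0, 50, 100, 550, 1000):
    print(step, float(schedule_fn(step)))
# 0 -> 0.0, 50 -> 5e-05, 100 -> 1e-04 (end of warmup), 550 -> 5e-05, 1000 -> 0.0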
622
+ def ctc_loss(
623
+ logits,
624
+ logits_attention_mask,
625
+ labels,
626
+ blank_id,
627
+ loss_reduction="mean",
628
+ output_emission_dict=False,
629
+ log_epsilon=-100000.0,
630
+ ):
631
+ """Computes CTC loss.
632
+ This function performs forward computation over an FSA with `N * 2` states
633
+ where `N` is the max number of labels. The states are split into two groups:
634
+ Phi states and emission states. A phi-state accepts repetition of
635
+ phi (blank)-symbols and transits to emission state when the correct label is
636
+ observed. An emission state accepts repetition of the label and transits to
637
+ the next phi state at any time (so-called epsilon-transition).
638
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
639
+ and `N` denotes the time steps in `labels`.
640
+ Args:
641
+ logits: (B, T, K)-array containing log-probabilities of each class.
642
+ logits_attention_mask: (B, T)-array. Attention mask for `logits`; padding
643
+ positions are 0 and its inverse is used as the logit padding indicator.
644
+ labels: (B, N)-array containing reference integer labels, right-padded with
645
+ -100, i.e. each row must be a run of real label ids followed by a run of
646
+ -100 padding entries.
647
+ blank_id: Id for blank token.
648
+ loss_reduction: one of "mean", "sum", "none"
649
+ - "none": no reduction is applied.
650
+ - "mean": output loss will be divided by target lengths and then the
651
+ mean over the batch is taken.
652
+ - "sum": output loss are summed over batch
653
+ output_emission_dict: whether to output additional information about the emission probs
654
+ Returns:
655
+ The reduced loss, or a pair `(loss, aux)` if `output_emission_dict` is True.
656
+ loss:
657
+ The reduced CTC loss; with `loss_reduction="none"`, a (B,)-array of per-sequence losses.
658
+ aux: Dictionary containing interim variables used for computing losses.
659
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
660
+ phi-state corresponding to the n-th label.
661
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
662
+ emission-state corresponding to the n-th label.
663
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
664
+ corresponding to each time frame.
665
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
666
+ corresponding to each time frame.
667
+ """
668
+ # label paddings are indicated by -100
669
+ labelpaddings = labels < 0
670
+ # logit paddings are the inverse of attention_mask
671
+ logitpaddings = ~logits_attention_mask
672
+
673
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
674
+ batchsize, unused_maxinputlen, num_classes = logits.shape
675
+ batchsize_, maxlabellen = labels.shape
676
+
677
+ logprobs = jax.nn.log_softmax(logits)
678
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
679
+
680
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
681
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
682
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
683
+
684
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
685
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
686
+
687
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
688
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
689
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
690
+
691
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
692
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
693
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
694
+
695
+ def loop_body(prev, x):
696
+ prev_phi, prev_emit = prev
697
+ # emit-to-phi epsilon transition, except if the next label is repetition
698
+ prev_phi_orig = prev_phi
699
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
700
+
701
+ logprob_emit, logprob_phi, pad = x
702
+
703
+ # phi-to-emit transition
704
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
705
+ # self-loop transition
706
+ next_phi = prev_phi + logprob_phi
707
+ # emit-to-phi blank transition only when the next label is repetition
708
+ next_phi = next_phi.at[:, 1:].set(
709
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
710
+ )
711
+
712
+ pad = pad.reshape((batchsize, 1))
713
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
714
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
715
+
716
+ return (next_phi, next_emit), (next_phi, next_emit)
717
+
718
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
719
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
720
+
721
+ # last row needs to be updated with the last epsilon transition
722
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
723
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
724
+
725
+ # extract per_seq_loss
726
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
727
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
728
+
729
+ if loss_reduction == "mean":
730
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
731
+ loss = (per_seq_loss / target_lengths).mean()
732
+ elif loss_reduction == "sum":
733
+ loss = per_seq_loss.sum()
734
+ else:
735
+ loss = per_seq_loss
736
+
737
+ if not output_emission_dict:
738
+ return loss
739
+
740
+ return loss, {
741
+ "logalpha_phi": logalpha_phi,
742
+ "logalpha_emit": logalpha_emit,
743
+ "logprobs_phi": logprobs_phi,
744
+ "logprobs_emit": logprobs_emit,
745
+ }
746
+
747
+
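To make the label conventions used by `ctc_loss` concrete, a toy example (made-up label ids, plain NumPy instead of jax.numpy) of how the -100 padding turns into `labelpaddings` and how the `repeat` mask flags consecutive identical labels:

import numpy as np

labels = np.array([[12, 12, 7, -100], [4, 5, 5, 5]])     # -100 marks label padding
labelpaddings = labels < 0
# repeat[b, n] == 1.0 when label[b, n] == label[b, n + 1]
repeat = (labels[:, :-1] == labels[:, 1:]).astype(np.float32)
repeat = np.pad(repeat, ((0, 0), (0, 1)))
print(labelpaddings.astype(int).tolist())   # [[0, 0, 0, 1], [0, 0, 0, 0]]
print(repeat.tolist())                      # [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0]]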
748
+ def make_dataset(seed=42):
749
+ # Pre-processing dataset
750
+ import re
751
+
752
+ def map_nst(entry):
753
+ text = entry["text"].lower()
754
+ text = text.replace("(...vær stille under dette opptaket...)", "")
755
+ text = re.sub('[áàâ]', 'a', text)
756
+ text = re.sub('[ä]', 'æ', text)
757
+ text = re.sub('[éèëê]', 'e', text)
758
+ text = re.sub('[íìïî]', 'i', text)
759
+ text = re.sub('[óòô]', 'o', text)  # keep 'ö' for the ø-mapping below
760
+ text = re.sub('[ö]', 'ø', text)
761
+ text = re.sub('[ç]', 'c', text)
762
+ text = re.sub('[úùüû]', 'u', text)
763
+ # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
764
+ text = re.sub('\s+', ' ', text)
765
+ return {"text": text}
766
+
767
+ def filter_nst(entry):
768
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
769
+ return False # Too short
770
+ if re.match("pIW|CA", entry["type"]):
771
+ return False # Spelling out words
772
+ return True
773
+
774
+ def filter_npsc(entry):
775
+ # False if there are digits in the text
776
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
777
+ return False # Too short
778
+ if re.search("\d", entry["text"]):
779
+ return False
780
+ return True
781
+
782
+ def map_npsc(entry):
783
+ batch = {"text": entry["text"].lower()}
784
+ batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
785
+ batch["text"] = re.sub('[ä]', 'æ', batch["text"])
786
+ batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
787
+ batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
788
+ batch["text"] = re.sub('[óòô]', 'o', batch["text"])  # keep 'ö' for the ø-mapping below
789
+ batch["text"] = re.sub('[ö]', 'ø', batch["text"])
790
+ batch["text"] = re.sub('[ç]', 'c', batch["text"])
791
+ batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
792
+ batch["text"] = re.sub('\s', ' ', batch["text"])
793
+ batch["text"] = re.sub('<ee>', 'eee', batch["text"])
794
+ batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
795
+ batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
796
+ batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
797
+ # batch["text"] = re.sub('<inaudible>', '?', batch["text"])
798
+ if "<" in batch["text"]:
799
+ raise ValueError(batch["text"])
800
+ return batch
801
+
802
+ nst = datasets.load_dataset("NbAiLab/NST", "no-close")
803
+ npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
804
+ # TODO NST_hesitate
805
+
806
+ split = len(npsc["train"]) / (len(npsc["train"]) + len(npsc["validation"])) # Use same train/val ratio as NPSC
807
+ nst_train = nst["train"].train_test_split(train_size=split, seed=seed)
808
+ nst["train"] = nst_train["train"]
809
+ nst["validation"] = nst_train["test"]
810
+
811
+ nst = nst.filter(filter_nst).map(map_nst).shuffle(seed=seed)
812
+ npsc = npsc.filter(filter_npsc).map(map_npsc).shuffle(seed=seed)
813
+
814
+ npsc_base = npsc.remove_columns([col for col in npsc["train"].column_names if col not in ["text", "audio"]])
815
+ nst_base = nst.remove_columns([col for col in nst["train"].column_names if col not in ["text", "audio"]])
816
+
817
+ combined = {}
818
+ for split in "train", "validation", "test":
819
+ probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples
820
+ probs = (probs / probs.sum()).tolist()
821
+ comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
822
+ combined[split] = comb
823
+
824
+ return datasets.DatasetDict(**combined)
825
+
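A self-contained look at a few of the character substitutions performed in `map_nst`/`map_npsc` above, applied to a made-up sentence (only a subset of the rules is repeated here):

import re

text = "vår kafé lå ved fjorden  –  dét var flott".lower()
for pattern, repl in [('[áàâ]', 'a'), ('[éèëê]', 'e'), ('[íìïî]', 'i'), ('[úùüû]', 'u')]:
    text = re.sub(pattern, repl, text)
text = re.sub(r'\s+', ' ', text)   # collapse repeated whitespace
print(text)   # "vår kafe lå ved fjorden – det var flott"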
826
+ def main():
827
+ # 1. Parse input arguments
828
+ # See all possible arguments in src/transformers/training_args.py
829
+ # or by passing the --help flag to this script.
830
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
831
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
832
+
833
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
834
+ # If we pass only one argument to the script and it's the path to a json file,
835
+ # let's parse it to get our arguments.
836
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
837
+ else:
838
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
839
+
840
+ # 2. Setup logging
841
+ # Make one log on every process with the configuration for debugging.
842
+ logging.basicConfig(
843
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
844
+ datefmt="%m/%d/%Y %H:%M:%S",
845
+ handlers=[logging.StreamHandler(sys.stdout)],
846
+ )
847
+ # Set the verbosity to info of the Transformers logger.
848
+ # We only want one process per machine to log things on the screen.
849
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
850
+ if jax.process_index() == 0:
851
+ datasets.utils.logging.set_verbosity_warning()
852
+ transformers.utils.logging.set_verbosity_info()
853
+ else:
854
+ datasets.utils.logging.set_verbosity_error()
855
+ transformers.utils.logging.set_verbosity_error()
856
+
857
+ # Set up wandb run
858
+ if jax.process_index() == 0:
859
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
860
+
861
+ logger.info("Training/evaluation parameters %s", training_args)
862
+
863
+ # Set the default TPU matmul precision and display the number of devices
864
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
865
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
866
+
867
+ # 4. Load dataset
868
+
869
+ set_seed(training_args.seed)
870
+ raw_datasets = make_dataset(seed=training_args.seed)
871
+
872
+ # raw_datasets = DatasetDict()
873
+
874
+ # if training_args.do_train:
875
+ # raw_datasets["train"] = load_dataset(
876
+ # data_args.dataset_name,
877
+ # data_args.dataset_config_name,
878
+ # split=data_args.train_split_name,
879
+ # cache_dir=data_args.dataset_cache_dir,
880
+ # use_auth_token=True if model_args.use_auth_token else None,
881
+ # )
882
+
883
+ # if training_args.do_eval:
884
+ # raw_datasets["eval"] = load_dataset(
885
+ # data_args.dataset_name,
886
+ # data_args.dataset_config_name,
887
+ # split=data_args.eval_split_name,
888
+ # cache_dir=data_args.dataset_cache_dir,
889
+ # use_auth_token=True if model_args.use_auth_token else None,
890
+ # )
891
+
892
+ # if training_args.do_predict:
893
+ # test_split = data_args.test_split_name.split("+")
894
+ # for split in test_split:
895
+ # raw_datasets[split] = load_dataset(
896
+ # data_args.dataset_name,
897
+ # data_args.dataset_config_name,
898
+ # split=split,
899
+ # cache_dir=data_args.dataset_cache_dir,
900
+ # use_auth_token=True if model_args.use_auth_token else None,
901
+ # )
902
+
903
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
904
+ raise ValueError(
905
+ "Training, evaluation and prediction are all disabled. At least one of "
906
+ "training, evaluation or prediction has to be done."
907
+ )
908
+
909
+ # if not training, there is no need to run multiple epochs
910
+ if not training_args.do_train:
911
+ training_args.num_train_epochs = 1
912
+
913
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
914
+ raise ValueError(
915
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
916
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
917
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
918
+ )
919
+
920
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
921
+ raise ValueError(
922
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
923
+ "Make sure to set `--text_column_name` to the correct text column - one of "
924
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
925
+ )
926
+
927
+ # 5. Load pretrained model, tokenizer, and feature extractor
928
+ #
929
+ # Distributed training:
930
+ # The .from_pretrained methods guarantee that only one local process can concurrently
931
+ config = Wav2Vec2Config.from_pretrained(
932
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
933
+ cache_dir=model_args.cache_dir,
934
+ revision=model_args.model_revision,
935
+ use_auth_token=True if model_args.use_auth_token else None,
936
+ )
937
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
938
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
939
+ cache_dir=model_args.cache_dir,
940
+ revision=model_args.model_revision,
941
+ use_auth_token=True if model_args.use_auth_token else None,
942
+ )
943
+ tokenizer = AutoTokenizer.from_pretrained(
944
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
945
+ cache_dir=model_args.cache_dir,
946
+ revision=model_args.model_revision,
947
+ use_auth_token=True if model_args.use_auth_token else None,
948
+ )
949
+ # update config according to training args, model args, and tokenizer attributes
950
+ config.update(
951
+ {
952
+ "feat_proj_dropout": model_args.feat_proj_dropout,
953
+ "attention_dropout": model_args.attention_dropout,
954
+ "hidden_dropout": model_args.hidden_dropout,
955
+ "final_dropout": model_args.final_dropout,
956
+ "mask_time_prob": model_args.mask_time_prob,
957
+ "mask_time_length": model_args.mask_time_length,
958
+ "mask_feature_prob": model_args.mask_feature_prob,
959
+ "mask_feature_length": model_args.mask_feature_length,
960
+ "gradient_checkpointing": training_args.gradient_checkpointing,
961
+ "layerdrop": model_args.layerdrop,
962
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
963
+ "ctc_zero_infinity": model_args.ctc_zero_infinity,
964
+ "pad_token_id": tokenizer.pad_token_id,
965
+ "vocab_size": tokenizer.vocab_size, # len(tokenizer),
966
+ "activation_dropout": model_args.activation_dropout,
967
+ }
968
+ )
969
+
970
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
971
+ raise ValueError(
972
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
973
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus,"
974
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is most likely "
975
+ "built on a lowercased corpus. In this case, set `tokenizer.do_lower_case` to `False`."
976
+ )
977
+
978
+ if training_args.precision == "full_mixed":
979
+ dtype = jnp.bfloat16
980
+ training_args.mixed_precision = True
981
+ elif training_args.precision == "half_mixed":
982
+ dtype = jnp.bfloat16
983
+ training_args.mixed_precision = False
984
+ else:
985
+ dtype = jnp.float32
986
+ training_args.mixed_precision = False
987
+
988
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
989
+ model_args.model_name_or_path,
990
+ config=config,
991
+ dtype=dtype,
992
+ cache_dir=model_args.cache_dir,
993
+ revision=model_args.model_revision,
994
+ use_auth_token=True if model_args.use_auth_token else None,
995
+ )
996
+
997
+ # 6. Resample speech dataset ALWAYS
998
+ raw_datasets = raw_datasets.cast_column(
999
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
1000
+ )
1001
+
1002
+ # 7. Preprocessing the datasets.
1003
+ # We need to read the audio files as arrays and tokenize the targets.
1004
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
1005
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
1006
+ max_target_length = data_args.max_label_length
1007
+ min_target_length = data_args.min_label_length
1008
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
1009
+ audio_column_name = data_args.audio_column_name
1010
+ num_workers = data_args.preprocessing_num_workers
1011
+ text_column_name = data_args.text_column_name
1012
+ model_input_name = feature_extractor.model_input_names[0]
1013
+ do_lower_case = data_args.do_lower_case
1014
+ dataset_name = data_args.dataset_name
1015
+ chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
1016
+ chars_to_ignore_regex = f'[{re.escape("".join(chars_to_ignore))}]'  # escape so '-' is not treated as a range
1017
+ # gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
1018
+ # gigaspeech_disfluencies = ["<other>", "<sil>"]
1019
+ # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
1020
+ # "[vocalized-noise]", "_1"]
1021
+ # swb_punctuations = ["{", "}", "[", "]-", "]"]
1022
+ # earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>", "<unk>"]
1023
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
1024
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
1025
+
1026
+ if training_args.do_train and data_args.max_train_samples is not None:
1027
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
1028
+
1029
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1030
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
1031
+
1032
+ if training_args.do_predict and data_args.max_test_samples is not None:
1033
+ for split in test_split:
1034
+ raw_datasets[split] = raw_datasets[split].select(range(data_args.max_test_samples))
1035
+
1036
+ if training_args.do_train and data_args.remove_punctuation:
1037
+
1038
+ def remove_punctuation(batch):
1039
+ batch[text_column_name] = (
1040
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
1041
+ )
+ return batch
1042
+
1043
+ raw_datasets["train"] = raw_datasets["train"].map(
1044
+ remove_punctuation,
1045
+ num_proc=data_args.preprocessing_num_workers,
1046
+ desc="removing punctuation from train split",
1047
+ )
1048
+
1049
+ # filter data where the targets are ignored in scoring
1050
+ def is_target_labels(input_str):
1051
+ return input_str.lower() not in ignore_segments
1052
+
1053
+ raw_datasets = raw_datasets.filter(
1054
+ is_target_labels,
1055
+ num_proc=num_workers,
1056
+ input_columns=[text_column_name],
1057
+ desc="filtering data where the targets are ignored in scoring",
1058
+ )
1059
+
1060
+ def prepare_dataset(batch):
1061
+ # process audio
1062
+ try:
1063
+ sample = batch[audio_column_name]
1064
+ except ValueError:
1065
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
1066
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1067
+ # process audio length
1068
+ batch[model_input_name] = inputs.input_values[0]
1069
+ batch["input_length"] = len(batch["input_values"])
1070
+
1071
+ # process targets
1072
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
1073
+
1074
+ # if dataset_name == "google/xtreme_s":
1075
+ # # Finally, we tokenize the processed text
1076
+ # batch["labels"] = tokenizer(input_str).input_ids
1077
+ # batch["labels_length"] = len(batch["labels"])
1078
+ # return batch
1079
+
1080
+ # # Common Voice 9
1081
+ # if input_str.startswith('"') and input_str.endswith('"'):
1082
+ # # we can remove trailing quotation marks as they do not affect the transcription
1083
+ # input_str = input_str[1:-1]
1084
+ # # normalize quotation marks
1085
+ # input_str = re.sub(r'["“”]', '"', input_str)
1086
+ # # normalize apostrophes
1087
+ # input_str = re.sub(r"[’']", "'", input_str)
1088
+ # # normalize hyphens
1089
+ # input_str = re.sub(r"[—–]", "-", input_str)
1090
+ # # replace double quotation marks with single
1091
+ # input_str = input_str.replace('""', '"')
1092
+ # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str):
1093
+ # # for CV9, we'll normalize the text to always finish with punctuation
1094
+ # if input_str[-1] not in [".", "?", "!"]:
1095
+ # input_str = input_str + "."
1096
+
1097
+ # # TEDLIUM-3
1098
+ # # delete the <unk> token from the text and replace spaced apostrophes with un-spaced
1099
+ # input_str = input_str.replace("<unk>", "").replace(" '", "'")
1100
+
1101
+ # # GigaSpeech
1102
+ # for disfluency in gigaspeech_disfluencies:
1103
+ # input_str = input_str.replace(disfluency, "")
1104
+ # # convert spelled out punctuation to symbolic form
1105
+ # for punctuation, replacement in gigaspeech_punctuation.items():
1106
+ # input_str = input_str.replace(punctuation, replacement)
1107
+ # if dataset_name == "speechcolab/gigaspeech" and len(input_str):
1108
+ # # for GS, we'll normalize the text to always finish with punctuation
1109
+ # if input_str[-1] not in [".", "?", "!"]:
1110
+ # input_str = input_str + "."
1111
+
1112
+ # # SWB
1113
+ # for disfluency in swb_disfluencies:
1114
+ # input_str = input_str.replace(disfluency, "")
1115
+ # # remove parenthesised text (test data only)
1116
+ # input_str = re.sub("[\(].*?[\)]", "", input_str)
1117
+ # for punctuation in swb_punctuations:
1118
+ # input_str = input_str.replace(punctuation, "")
1119
+ # # replace anomalous words with their correct transcriptions
1120
+ # split_str = input_str.split("/")
1121
+ # if len(split_str) > 1:
1122
+ # input_str = " ".join(
1123
+ # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1124
+
1125
+ # # Earnings 22
1126
+ # for disfluency in earnings_disfluencies:
1127
+ # input_str = input_str.replace(disfluency, "")
1128
+ # # replace mal-formatted ellipsis
1129
+ # input_str = input_str.replace("…", ".")
1130
+
1131
+ # JIWER compliance
1132
+ # remove multiple spaces
1133
+ input_str = re.sub(r"\s\s+", " ", input_str)
1134
+ # strip trailing spaces
1135
+ input_str = input_str.strip()
1136
+
1137
+ # Finally, we tokenize the processed text
1138
+ batch["labels"] = tokenizer(input_str).input_ids
1139
+ batch["labels_length"] = len(batch["labels"])
1140
+ return batch
1141
+
1142
+ vectorized_datasets = raw_datasets.map(
1143
+ prepare_dataset,
1144
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1145
+ num_proc=num_workers,
1146
+ desc="preprocess dataset",
1147
+ )
1148
+
1149
+ # filter data with inputs shorter than min_input_length or longer than max_input_length
1150
+ def is_audio_in_length_range(length):
1151
+ return length > min_input_length and length < max_input_length
1152
+
1153
+ vectorized_datasets = vectorized_datasets.filter(
1154
+ is_audio_in_length_range,
1155
+ num_proc=num_workers,
1156
+ input_columns=["input_length"],
1157
+ )
1158
+
1159
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1160
+ def is_labels_in_length_range(length):
1161
+ return length > min_target_length # and length < max_target_length
1162
+
1163
+ vectorized_datasets = vectorized_datasets.filter(
1164
+ is_labels_in_length_range,
1165
+ num_proc=num_workers,
1166
+ input_columns=["labels_length"],
1167
+ )
1168
+
1169
+ # for large datasets it is advised to run the preprocessing on a
1170
+ # single machine first with `args.preprocessing_only` since there will most likely
1171
+ # be a timeout when running the script in distributed mode.
1172
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1173
+ # cached dataset
1174
+ if data_args.preprocessing_only:
1175
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1176
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1177
+ return
1178
+
1179
+ # 8. Load Metrics
1180
+ wer_metric = load_metric("wer")
1181
+ cer_metric = load_metric("cer")
1182
+
1183
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1184
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1185
+
1186
+ pred_str = tokenizer.batch_decode(pred_ids)
1187
+ # we do not want to group tokens when computing the metrics
1188
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
1189
+
1190
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
1191
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
1192
+
1193
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1194
+
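As a rough, self-contained sanity check of what the WER number in `compute_metrics` measures (the strings below are made up; `jiwer`, which the `wer` metric is built on, must be installed):

import jiwer

label_str = ["det var en fin dag", "hun kom hjem"]
pred_str = ["det var en fin dag", "hun kom frem"]
print(jiwer.wer(label_str, pred_str))   # 0.125 -> 1 substitution out of 8 reference words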
1195
+ # 9. save feature extractor, tokenizer and config
1196
+ feature_extractor.save_pretrained(training_args.output_dir)
1197
+ tokenizer.save_pretrained(training_args.output_dir)
1198
+ config.save_pretrained(training_args.output_dir)
1199
+
1200
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1201
+
1202
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1203
+ processor=processor,
1204
+ input_padding="longest",
1205
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1206
+ max_label_length=data_args.max_label_length,
1207
+ )
1208
+
1209
+ # Enable tensorboard only on the master node
1210
+ has_tensorboard = is_tensorboard_available()
1211
+ if has_tensorboard and jax.process_index() == 0:
1212
+ try:
1213
+ from flax.metrics.tensorboard import SummaryWriter
1214
+
1215
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1216
+ except ImportError as ie:
1217
+ has_tensorboard = False
1218
+ logger.warning(
1219
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1220
+ )
1221
+ else:
1222
+ logger.warning(
1223
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1224
+ "Please run `pip install tensorboard` to enable."
1225
+ )
1226
+
1227
+ # 10. Handle the repository creation
1228
+ if training_args.push_to_hub:
1229
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1230
+ git_lfs_extensions = f.read()
1231
+ if "*.wandb" not in git_lfs_extensions:
1232
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1233
+ if training_args.hub_model_id is None:
1234
+ repo_name = get_full_repo_name(
1235
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1236
+ )
1237
+ else:
1238
+ repo_name = training_args.hub_model_id
1239
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1240
+
1241
+ # 11. Initialize our training
1242
+ rng = jax.random.PRNGKey(training_args.seed)
1243
+ rng, dropout_rng = jax.random.split(rng)
1244
+
1245
+ # Store some constants
1246
+ max_steps = int(training_args.max_steps)
1247
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1248
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1249
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1250
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1251
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1252
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1253
+
1254
+ if training_args.do_train:
1255
+ num_train_samples = len(vectorized_datasets["train"])
1256
+ steps_per_epoch = num_train_samples // batch_size_per_update
1257
+ if max_steps > 0:
1258
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1259
+ total_train_steps = max_steps
1260
+ else:
1261
+ num_epochs = int(training_args.num_train_epochs)
1262
+ total_train_steps = steps_per_epoch * num_epochs
1263
+
1264
+ # Create learning rate schedule
1265
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1267
+ total_train_steps,
1268
+ training_args.warmup_steps,
1269
+ training_args.learning_rate,
1270
+ )
1271
+
1272
+ # We use Optax's "masking" functionality to not apply weight decay
1273
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1274
+ # mask boolean with the same structure as the parameters.
1275
+ # The mask is True for parameters that should be decayed.
1276
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1277
+ # For FlaxT5, one should correct the layer norm parameter naming
1278
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1279
+ def decay_mask_fn(params):
1280
+ flat_params = traverse_util.flatten_dict(params)
1281
+ layer_norm_params = [
1282
+ (name, "scale")
1283
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1284
+ ]
1285
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1286
+ return traverse_util.unflatten_dict(flat_mask)
1287
+
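To see what `decay_mask_fn` returns, a toy run on a hypothetical parameter tree (not the real model params): biases and LayerNorm scales come out `False` (excluded from weight decay), everything else `True`.

from flax import traverse_util

params = {
    "encoder": {
        "layer_norm": {"scale": 0.0, "bias": 0.0},
        "attention": {"kernel": 0.0, "bias": 0.0},
    }
}
layer_norm_params = [
    (name, "scale") for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
]
flat_params = traverse_util.flatten_dict(params)
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
print(traverse_util.unflatten_dict(flat_mask))
# {'encoder': {'layer_norm': {'scale': False, 'bias': False},
#              'attention': {'kernel': True, 'bias': False}}}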
1288
+ if training_args.adafactor:
1289
+ # Create Adafactor optimizer
1290
+ optim = optax.adafactor(
1291
+ learning_rate=linear_decay_lr_schedule_fn,
1292
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1293
+ weight_decay_rate=training_args.weight_decay,
1294
+ weight_decay_mask=decay_mask_fn,
1295
+ )
1296
+ else:
1297
+ # Create AdamW optimizer
1298
+ optim = optax.adamw(
1299
+ learning_rate=linear_decay_lr_schedule_fn,
1300
+ b1=training_args.adam_beta1,
1301
+ b2=training_args.adam_beta2,
1302
+ eps=training_args.adam_epsilon,
1303
+ weight_decay=training_args.weight_decay,
1304
+ mask=decay_mask_fn,
1305
+ )
1306
+
1307
+ # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
1308
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1309
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
1310
+ else:
1311
+ num_epochs = 0
1312
+ total_train_steps = 0
1313
+ num_train_samples = 0
1314
+ optim = None
1315
+
1316
+ # Setup train state
1317
+ state = MixedPrecisionTrainState.create(
1318
+ apply_fn=model.__call__,
1319
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1320
+ params=model.params,
1321
+ tx=optim,
1322
+ to_dtype=to_dtype,
1323
+ dropout_rng=dropout_rng,
1324
+ max_grad_norm=training_args.max_grad_norm,
1325
+ )
1326
+
1327
+ # Replicate the train state on each device
1328
+ state = state.replicate()
1329
+ blank_id = model.config.pad_token_id
1330
+
1331
+ # Define gradient update step fn
1332
+ def train_step(state, batch):
1333
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1334
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1335
+
1336
+ def compute_loss(params, minibatch):
1337
+ labels = minibatch.pop("labels")
1338
+ logits = state.apply_fn(
1339
+ **minibatch,
1340
+ params=params,
1341
+ dropout_rng=dropout_rng,
1342
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1343
+ train=True,
1344
+ )[0]
1345
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], minibatch["attention_mask"])
1346
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1347
+
1348
+ return loss
1349
+
1350
+ grad_fn = jax.value_and_grad(compute_loss)
1351
+
1352
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1353
+ loss, grad = grad_fn(to_dtype(state.params), batch)
1354
+
1355
+ # Custom gradient accumulation
1356
+ else:
1357
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1358
+ batch = jax.tree_map(
1359
+ lambda x: x.reshape(
1360
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1361
+ ),
1362
+ batch,
1363
+ )
1364
+
1365
+ def accum_minibatch_step(accum_grad, minibatch):
1366
+ # compute loss, num labels and grad over minibatch and accumulate
1367
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
1368
+ return jax.tree_map(jnp.add, accum_grad, grad), loss
1369
+
1370
+ # create an initial state for accumulating losses, num labels and gradients
1371
+ init_grad = jax.tree_map(jnp.zeros_like, to_dtype(state.params))
1372
+ # loop accum minibatch step over the number of gradient accumulation steps
1373
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
1374
+
1375
+ # update state
1376
+ new_state = state.apply_gradients(
1377
+ grads=grad,
1378
+ dropout_rng=new_dropout_rng,
1379
+ to_dtype=to_dtype,
1380
+ )
1381
+
1382
+ # compute gradient norms over all layers and globally for detailed monitoring
1383
+ layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
1384
+ logs = {
1385
+ "layer_grad_norm": layer_grad_norm,
1386
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
1387
+ }
1388
+
1389
+ # compute parameter norms over all layers and globally for detailed monitoring
1390
+ layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
1391
+ logs["layer_param_norm"] = layer_param_norm
1392
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
1393
+
1394
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1395
+ metrics.update(logs)
1396
+
1397
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1398
+ # metrics = to_fp32(metrics)
1399
+
1400
+ return new_state, metrics
1401
+
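The `jax.lax.pmean(..., axis_name="batch")` calls above average per-device values inside `pmap`; a minimal standalone illustration (runs on however many local devices are available, including a single CPU device):

import jax
import jax.numpy as jnp

n = jax.local_device_count()
per_device_loss = jnp.arange(n, dtype=jnp.float32)     # one toy loss value per device
mean_loss = jax.pmap(lambda x: jax.lax.pmean(x, axis_name="batch"), axis_name="batch")(per_device_loss)
print(mean_loss)   # every device ends up holding the same averaged value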
1402
+ # Define eval fn
1403
+ def eval_step(params, batch):
1404
+ labels = batch.pop("labels")
1405
+ logits = model(**batch, params=params, train=False)[0]
1406
+
1407
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
1408
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1409
+
1410
+ pred_ids = jnp.argmax(logits, axis=-1)
1411
+
1412
+ # summarize metrics
1413
+ metrics = {"loss": loss}
1414
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1415
+ # metrics = to_fp32(metrics)
1416
+ return metrics, pred_ids
1417
+
1418
+ # Create parallel version of the train and eval step
1419
+ if training_args.do_train:
1420
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1421
+
1422
+ if training_args.do_eval:
1423
+ p_eval_step = jax.pmap(eval_step, "batch")
1424
+
1425
+ def run_evaluation(step):
1426
+ if training_args.do_eval:
1427
+ # ======================== Evaluating ==============================
1428
+ eval_metrics = []
1429
+ eval_preds = []
1430
+ eval_labels = []
1431
+
1432
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1433
+ eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
1434
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1435
+
1436
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1437
+ samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
1438
+ batch = data_collator(samples)
1439
+ labels = batch["labels"]
1440
+
1441
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1442
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1443
+ eval_metrics.append(metrics)
1444
+
1445
+ eval_labels.extend(labels)
1446
+
1447
+ # normalize eval metrics
1448
+ eval_metrics = get_metrics(eval_metrics)
1449
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1450
+ eval_metrics = to_fp32(eval_metrics)
1451
+
1452
+ # always run compute metrics
1453
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1454
+ eval_metrics.update(error_rate_metric)
1455
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1456
+
1457
+ # Print metrics and update progress bar
1458
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1459
+ epochs.write(desc)
1460
+ epochs.desc = desc
1461
+
1462
+ # Save metrics
1463
+ write_wandb_log(eval_metrics, step, prefix="eval")
1464
+ write_wandb_pred(pred_str, label_str, step)
1465
+ # if has_tensorboard and jax.process_index() == 0:
1466
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1467
+
1468
+ def save_checkpoint(step):
1469
+ # save and push checkpoint to the hub
1470
+ if jax.process_index() == 0:
1471
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
1472
+ model.save_pretrained(training_args.output_dir, params=params)
1473
+ tokenizer.save_pretrained(training_args.output_dir)
1474
+ if training_args.push_to_hub:
1475
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1476
+
1477
+ logger.info("***** Running training *****")
1478
+ logger.info(f" Num examples = {num_train_samples}")
1479
+ logger.info(f" Num Epochs = {num_epochs}")
1480
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1481
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1482
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1483
+ logger.info(f" Total optimization steps = {total_train_steps}")
1484
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
1485
+ logger.info(f" Use scan: {config.use_scan}")
1486
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
1487
+
1488
+ train_time = cur_step = 0
1489
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1490
+ for epoch in epochs:
1491
+ if training_args.do_train:
1492
+ # ======================== Training ================================
1493
+ train_start = time.time()
1494
+
1495
+ # Create sampling rng
1496
+ rng, input_rng = jax.random.split(rng)
1497
+
1498
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1499
+ train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
1500
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1501
+
1502
+ # Gather the indices for creating the batch and do a training step
1503
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1504
+ samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
1505
+ batch = data_collator(samples)
1506
+ batch = shard(batch.data)
1507
+ try:
1508
+ state, train_metric = p_train_step(state, batch)
1509
+ except TypeError as e:
1510
+ logger.warning(f"Encountered the following error: \n{e}")
1511
+
1512
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1513
+
1514
+ if cur_step % training_args.logging_steps == 0:
1515
+ # Save metrics
1516
+ train_metric = unreplicate(train_metric)
1517
+ train_time += time.time() - train_start
1518
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1519
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix="train")
1520
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1521
+ # if has_tensorboard and jax.process_index() == 0:
1522
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1523
+
1524
+ epochs.write(
1525
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1526
+ )
1527
+
1528
+ if cur_step % total_train_steps == 0:
1529
+ break
1530
+
1531
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1532
+ run_evaluation(cur_step)
1533
+
1534
+ if cur_step % training_args.save_steps == 0:
1535
+ save_checkpoint(cur_step)
1536
+
1537
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1538
+ # run evaluation at the end of the epoch if eval steps are not specified
1539
+ run_evaluation(cur_step)
1540
+ save_checkpoint(cur_step)
1541
+
1542
+ if training_args.do_train:
1543
+ save_checkpoint(cur_step)
1544
+
1545
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1546
+
1547
+ if training_args.do_eval:
1548
+ run_evaluation(cur_step)
1549
+
1550
+ # TODO: collapse 'do_predict' into the run_evaluation function
1551
+ if training_args.do_predict:
1552
+ for split in test_split:
1553
+ # ======================== Evaluating ==============================
1554
+ eval_metrics = []
1555
+ eval_preds = []
1556
+ eval_labels = []
1557
+
1558
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
1559
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1560
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1561
+
1562
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
1563
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1564
+ batch = data_collator(samples)
1565
+ labels = batch["labels"]
1566
+
1567
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1568
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1569
+ eval_metrics.append(metrics)
1570
+
1571
+ eval_labels.extend(labels)
1572
+
1573
+ # normalize eval metrics
1574
+ eval_metrics = get_metrics(eval_metrics)
1575
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1576
+ eval_metrics = to_fp32(eval_metrics)
1577
+
1578
+ # always run compute metrics
1579
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1580
+ eval_metrics.update(error_rate_metric)
1581
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1582
+
1583
+ # Print metrics and update progress bar
1584
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1585
+ epochs.write(desc)
1586
+ epochs.desc = desc
1587
+
1588
+ # Save metrics
1589
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
1590
+ write_wandb_pred(pred_str, label_str, cur_step, prefix=split)
1591
+ # if has_tensorboard and jax.process_index() == 0:
1592
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
1593
+
1594
+
1595
+ if __name__ == "__main__":
1596
+ main()
wandb/run-20220729_183213-356uc50u/files/config.yaml ADDED
@@ -0,0 +1,33 @@
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.9
7
+ code_path: code/run_flax_speech_recognition_ctc.py
8
+ framework: huggingface
9
+ huggingface_version: 4.21.0
10
+ is_jupyter_run: false
11
+ is_kaggle_kernel: false
12
+ python_version: 3.8.10
13
+ start_time: 1659119533
14
+ t:
15
+ 1:
16
+ - 1
17
+ - 2
18
+ - 3
19
+ - 11
20
+ - 12
21
+ 2:
22
+ - 1
23
+ - 2
24
+ - 3
25
+ - 11
26
+ - 12
27
+ 3:
28
+ - 13
29
+ 4: 3.8.10
30
+ 5: 0.12.9
31
+ 6: 4.21.0
32
+ 8:
33
+ - 5
wandb/run-20220729_183213-356uc50u/files/output.log ADDED
@@ -0,0 +1,253 @@
1
+ INFO:__main__:Training/evaluation parameters FlaxTrainingArguments(
2
+ _n_gpu=0,
3
+ adafactor=False,
4
+ adam_beta1=0.9,
5
+ adam_beta2=0.999,
6
+ adam_epsilon=1e-08,
7
+ auto_find_batch_size=False,
8
+ bf16=False,
9
+ bf16_full_eval=False,
10
+ data_seed=None,
11
+ dataloader_drop_last=False,
12
+ dataloader_num_workers=0,
13
+ dataloader_pin_memory=True,
14
+ ddp_bucket_cap_mb=None,
15
+ ddp_find_unused_parameters=None,
16
+ debug=[],
17
+ deepspeed=None,
18
+ disable_tqdm=False,
19
+ do_eval=True,
20
+ do_predict=False,
21
+ do_train=True,
22
+ eval_accumulation_steps=None,
23
+ eval_delay=0,
24
+ eval_steps=500,
25
+ evaluation_strategy=steps,
26
+ fp16=False,
27
+ fp16_backend=auto,
28
+ fp16_full_eval=False,
29
+ fp16_opt_level=O1,
30
+ fsdp=[],
31
+ fsdp_min_num_params=0,
32
+ fsdp_transformer_layer_cls_to_wrap=None,
33
+ full_determinism=False,
34
+ gradient_accumulation_steps=1,
35
+ gradient_checkpointing=True,
36
+ greater_is_better=None,
37
+ group_by_length=True,
38
+ half_precision_backend=auto,
39
+ hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst,
40
+ hub_private_repo=False,
41
+ hub_strategy=every_save,
42
+ hub_token=<HUB_TOKEN>,
43
+ ignore_data_skip=False,
44
+ include_inputs_for_metrics=False,
45
+ jit_mode_eval=False,
46
+ label_names=None,
47
+ label_smoothing_factor=0.0,
48
+ learning_rate=0.0001,
49
+ length_column_name=input_length,
50
+ load_best_model_at_end=False,
51
+ local_rank=-1,
52
+ log_level=-1,
53
+ log_level_replica=-1,
54
+ log_on_each_node=True,
55
+ logging_dir=./runs/Jul29_18-32-09_t1v-n-eedfb410-w-0,
56
+ logging_first_step=False,
57
+ logging_nan_inf_filter=True,
58
+ logging_steps=100,
59
+ logging_strategy=steps,
60
+ lr_scheduler_type=linear,
61
+ matmul_precision=default,
62
+ max_grad_norm=1.0,
63
+ max_steps=-1,
64
+ metric_for_best_model=None,
65
+ mp_parameters=,
66
+ multisteps=False,
67
+ no_cuda=False,
68
+ num_train_epochs=40.0,
69
+ optim=adamw_hf,
70
+ output_dir=./,
71
+ overwrite_output_dir=True,
72
+ past_index=-1,
73
+ per_device_eval_batch_size=32,
74
+ per_device_train_batch_size=32,
75
+ precision=full,
76
+ prediction_loss_only=False,
77
+ push_to_hub=True,
78
+ push_to_hub_model_id=None,
79
+ push_to_hub_organization=None,
80
+ push_to_hub_token=<PUSH_TO_HUB_TOKEN>,
81
+ ray_scope=last,
82
+ remove_unused_columns=True,
83
+ report_to=['tensorboard', 'wandb'],
84
+ resume_from_checkpoint=None,
85
+ run_name=./,
86
+ save_on_each_node=False,
87
+ save_steps=500,
88
+ save_strategy=steps,
89
+ save_total_limit=3,
90
+ seed=42,
91
+ sharded_ddp=[],
92
+ skip_memory_metrics=True,
93
+ tf32=None,
94
+ torchdynamo=None,
95
+ tpu_metrics_debug=False,
96
+ tpu_num_cores=None,
97
+ use_ipex=False,
98
+ use_legacy_prediction_loop=False,
99
+ warmup_ratio=0.0,
100
+ warmup_steps=4000,
101
+ weight_decay=0.0,
102
+ xpu_backend=None,
103
+ )
104
+ INFO:__main__:JAX devices: 8, matmul precision: default
105
+ Downloading and preparing dataset nst/no-close to /home/javierr/.cache/huggingface/datasets/NbAiLab___nst/no-close/1.0.0/c9a1b1da598ea4a1b584c09ff0e7b0e06974f08bd0329959417147f3f5866f53...
106
+ Downloading builder script: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13.1k/13.1k [00:00<00:00, 154kB/s]
107
+ Downloading data files: 0%| | 0/9 [00:00<?, ?it/s]
108
+
109
+ Downloading data: 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 32.6M/33.4M [00:01<00:00, 33.4MB/s]
110
+
111
+
112
+
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+ Downloading data files: 11%|████████████████████████▊ | 1/9 [00:34<04:39, 34.99s/it]
125
+
126
+ Downloading data: 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 29.2M/33.4M [00:01<00:00, 40.0MB/s]
127
+
128
+
129
+
130
+
131
+
132
+
133
+ Downloading data files: 22%|█████████████████████████████████████████████████▌ | 2/9 [00:52<02:53, 24.73s/it]
134
+
135
+ Downloading data: 91%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 30.6M/33.6M [00:01<00:00, 37.1MB/s]
136
+
137
+
138
+
139
+
140
+
141
+
142
+ Downloading data files: 33%|██████████████████████████████████████████████████████████████████████████▎ | 3/9 [01:09<02:08, 21.40s/it]
143
+
144
+ Downloading data: 99%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 33.2M/33.5M [00:01<00:00, 41.7MB/s]
145
+
146
+
147
+
148
+
149
+
150
+
151
+ Downloading data files: 44%|███████████████████████████████████████████████████████████████████████████████████████████████████ | 4/9 [01:29<01:42, 20.55s/it]
152
+
153
+ Downloading data: 94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 31.3M/33.3M [00:01<00:00, 39.8MB/s]
154
+
155
+
156
+
157
+
158
+
159
+
160
+ Downloading data files: 56%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 5/9 [01:46<01:17, 19.42s/it]
161
+
162
+ Downloading data: 98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 32.8M/33.4M [00:01<00:00, 35.0MB/s]
163
+
164
+
165
+
166
+
167
+
168
+
169
+
170
+
171
+
172
+ Downloading data files: 67%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                          | 6/9 [02:09<01:02, 20.72s/it]
173
+ Downloading data: 88%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 29.5M/33.4M [00:01<00:00, 36.6MB/s]
174
+
175
+
176
+
177
+
178
+
179
+
180
+
181
+ Downloading data: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 33.6M/33.6M [00:01<00:00, 24.7MB/s]
182
+ Downloading data: 90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 30.3M/33.6M [00:01<00:00, 38.8MB/s]
183
+
184
+
185
+
186
+
187
+
188
+
189
+ Downloading data files: 89%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 8/9 [02:45<00:19, 19.04s/it]
190
+
191
+ Downloading data: 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 21.0M/25.0M [00:01<00:00, 32.3MB/s]
192
+
193
+
194
+
195
+
196
+ Downloading data files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [02:59<00:00, 19.91s/it]
197
+ Downloading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29.4M/29.4M [00:01<00:00, 23.8MB/s]
198
+ Downloading data: 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 27.5M/29.4M [00:01<00:00, 38.6MB/s]
199
+
200
+
201
+
202
+
203
+
204
+
205
+ Downloading data files: 33%|██████████████████████████████████████████████████████████████████████████▎ | 1/3 [00:17<00:34, 17.06s/it]
206
+
207
+ Downloading data: 98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 28.9M/29.3M [00:01<00:00, 35.3MB/s]
208
+
209
+
210
+
211
+
212
+
213
+
214
+ Downloading data files: 67%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 2/3 [00:34<00:17, 17.12s/it]
215
+
216
+ Downloading data: 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 12.1M/13.0M [00:00<00:00, 25.5MB/s]
217
+
218
+
219
+
220
+ Downloading data files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:45<00:00, 15.23s/it]
221
+ Traceback (most recent call last):
222
+ File "/data/flax/lib/python3.8/site-packages/datasets/features/audio.py", line 84, in encode_example
223
+ import soundfile as sf # soundfile is a dependency of librosa, needed to decode audio files.
224
+ ModuleNotFoundError: No module named 'soundfile'
225
+ The above exception was the direct cause of the following exception:
226
+ Traceback (most recent call last):
227
+ File "run_flax_speech_recognition_ctc.py", line 1596, in <module>
228
+ main()
229
+ File "run_flax_speech_recognition_ctc.py", line 870, in main
230
+ raw_datasets = make_dataset(seed=training_args.seed)
231
+ File "run_flax_speech_recognition_ctc.py", line 802, in make_dataset
232
+ nst = datasets.load_dataset("NbAiLab/NST", "no-close")
233
+ File "/data/flax/lib/python3.8/site-packages/datasets/load.py", line 1746, in load_dataset
234
+ builder_instance.download_and_prepare(
235
+ File "/data/flax/lib/python3.8/site-packages/datasets/builder.py", line 704, in download_and_prepare
236
+ self._download_and_prepare(
237
+ File "/data/flax/lib/python3.8/site-packages/datasets/builder.py", line 1227, in _download_and_prepare
238
+ super()._download_and_prepare(dl_manager, verify_infos, check_duplicate_keys=verify_infos)
239
+ File "/data/flax/lib/python3.8/site-packages/datasets/builder.py", line 793, in _download_and_prepare
240
+ self._prepare_split(split_generator, **prepare_split_kwargs)
241
+ File "/data/flax/lib/python3.8/site-packages/datasets/builder.py", line 1218, in _prepare_split
242
+ example = self.info.features.encode_example(record)
243
+ File "/data/flax/lib/python3.8/site-packages/datasets/features/features.py", line 1614, in encode_example
244
+ return encode_nested_example(self, example)
245
+ File "/data/flax/lib/python3.8/site-packages/datasets/features/features.py", line 1165, in encode_nested_example
246
+ {
247
+ File "/data/flax/lib/python3.8/site-packages/datasets/features/features.py", line 1166, in <dictcomp>
248
+ k: encode_nested_example(sub_schema, sub_obj, level=level + 1)
249
+ File "/data/flax/lib/python3.8/site-packages/datasets/features/features.py", line 1220, in encode_nested_example
250
+ return schema.encode_example(obj) if obj is not None else None
251
+ File "/data/flax/lib/python3.8/site-packages/datasets/features/audio.py", line 86, in encode_example
252
+ raise ImportError("To support encoding audio data, please install 'soundfile'.") from err
253
+ ImportError: To support encoding audio data, please install 'soundfile'.
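The run above aborts because the datasets library needs the optional soundfile package (libsndfile bindings) to encode the audio column, and the requirements.txt captured below contains neither soundfile nor librosa. A minimal sketch of guarding for the dependency before the load_dataset call that failed; only the dataset name and config come from the traceback above, the rest is illustrative:

import importlib.util

# Fail fast if the optional audio backend is missing; installing it
# (pip install soundfile) resolves the ImportError shown above.
if importlib.util.find_spec("soundfile") is None:
    raise SystemExit("Missing 'soundfile'; run: pip install soundfile")

import datasets

# The same call that raised the error in run_flax_speech_recognition_ctc.py, line 802.
nst = datasets.load_dataset("NbAiLab/NST", "no-close")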
wandb/run-20220729_183213-356uc50u/files/requirements.txt ADDED
@@ -0,0 +1,137 @@
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ astunparse==1.6.3
5
+ async-timeout==4.0.2
6
+ attrs==21.4.0
7
+ backcall==0.2.0
8
+ cachetools==4.2.4
9
+ certifi==2021.10.8
10
+ charset-normalizer==2.0.10
11
+ chex==0.1.3
12
+ click==8.0.3
13
+ cloud-tpu-client==0.10
14
+ cloud-tpu-profiler==2.4.0
15
+ clu==0.0.6
16
+ colorama==0.4.5
17
+ commonmark==0.9.1
18
+ configparser==5.2.0
19
+ contextlib2==21.6.0
20
+ cycler==0.11.0
21
+ datasets==2.4.0
22
+ decorator==5.1.0
23
+ dill==0.3.4
24
+ dm-tree==0.1.6
25
+ docker-pycreds==0.4.0
26
+ etils==0.6.0
27
+ filelock==3.4.2
28
+ flatbuffers==2.0
29
+ flax==0.5.3
30
+ fonttools==4.28.5
31
+ frozenlist==1.2.0
32
+ fsspec==2021.11.1
33
+ future==0.18.2
34
+ gast==0.4.0
35
+ gitdb==4.0.9
36
+ gitpython==3.1.26
37
+ google-api-core==1.31.5
38
+ google-api-python-client==1.8.0
39
+ google-auth-httplib2==0.1.0
40
+ google-auth-oauthlib==0.4.6
41
+ google-auth==2.3.3
42
+ google-pasta==0.2.0
43
+ googleapis-common-protos==1.54.0
44
+ grpcio==1.43.0
45
+ h5py==3.6.0
46
+ httplib2==0.20.2
47
+ huggingface-hub==0.2.1
48
+ idna==3.3
49
+ importlib-metadata==4.10.0
50
+ importlib-resources==5.4.0
51
+ ipython==7.31.0
52
+ jax==0.3.15
53
+ jaxlib==0.3.15
54
+ jedi==0.18.1
55
+ joblib==1.1.0
56
+ keras-preprocessing==1.1.2
57
+ keras==2.7.0
58
+ kiwisolver==1.3.2
59
+ libclang==12.0.0
60
+ libtpu-nightly==0.1.dev20220722
61
+ markdown==3.3.6
62
+ matplotlib-inline==0.1.3
63
+ matplotlib==3.5.1
64
+ ml-collections==0.1.0
65
+ msgpack==1.0.3
66
+ multidict==5.2.0
67
+ multiprocess==0.70.12.2
68
+ numpy==1.22.0
69
+ oauth2client==4.1.3
70
+ oauthlib==3.1.1
71
+ opt-einsum==3.3.0
72
+ optax==0.1.3
73
+ packaging==21.3
74
+ pandas==1.3.5
75
+ parso==0.8.3
76
+ pathtools==0.1.2
77
+ pexpect==4.8.0
78
+ pickleshare==0.7.5
79
+ pillow==9.0.0
80
+ pip==22.2.1
81
+ pkg-resources==0.0.0
82
+ promise==2.3
83
+ prompt-toolkit==3.0.24
84
+ protobuf==3.19.1
85
+ psutil==5.9.0
86
+ ptyprocess==0.7.0
87
+ pyarrow==6.0.1
88
+ pyasn1-modules==0.2.8
89
+ pyasn1==0.4.8
90
+ pygments==2.11.1
91
+ pyparsing==3.0.6
92
+ python-dateutil==2.8.2
93
+ pytz==2021.3
94
+ pyyaml==6.0
95
+ regex==2021.11.10
96
+ requests-oauthlib==1.3.0
97
+ requests==2.27.0
98
+ responses==0.18.0
99
+ rich==11.2.0
100
+ rsa==4.8
101
+ sacremoses==0.0.46
102
+ scipy==1.7.3
103
+ sentry-sdk==1.5.2
104
+ setuptools==44.0.0
105
+ shortuuid==1.0.8
106
+ six==1.16.0
107
+ smmap==5.0.0
108
+ subprocess32==3.5.4
109
+ tensorboard-data-server==0.6.1
110
+ tensorboard-plugin-wit==1.8.0
111
+ tensorboard==2.7.0
112
+ tensorflow-cpu==2.7.0
113
+ tensorflow-datasets==4.4.0
114
+ tensorflow-estimator==2.7.0
115
+ tensorflow-io-gcs-filesystem==0.23.1
116
+ tensorflow-metadata==1.5.0
117
+ tensorflow==2.7.0
118
+ tensorstore==0.1.21
119
+ termcolor==1.1.0
120
+ tokenizers==0.11.2
121
+ toolz==0.11.2
122
+ torch==1.11.0+cpu
123
+ tqdm==4.62.3
124
+ traitlets==5.1.1
125
+ transformers==4.21.0
126
+ typing-extensions==4.3.0
127
+ uritemplate==3.0.1
128
+ urllib3==1.26.7
129
+ wandb==0.12.9
130
+ wcwidth==0.2.5
131
+ werkzeug==2.0.2
132
+ wheel==0.37.1
133
+ wrapt==1.13.3
134
+ xxhash==2.0.2
135
+ yarl==1.7.2
136
+ yaspin==2.1.0
137
+ zipp==3.7.0
wandb/run-20220729_183213-356uc50u/files/wandb-metadata.json ADDED
@@ -0,0 +1,67 @@
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2022-07-29T18:32:17.029179",
5
+ "startedAt": "2022-07-29T18:32:13.606321",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--model_name_or_path=facebook/wav2vec2-xls-r-1b",
11
+ "--hub_model_id=NbAiLab/wav2vec2-1b-npsc-nst",
12
+ "--output_dir=./",
13
+ "--overwrite_output_dir",
14
+ "--num_train_epochs=40",
15
+ "--per_device_train_batch_size=32",
16
+ "--per_device_eval_batch_size=32",
17
+ "--gradient_accumulation_steps=1",
18
+ "--learning_rate=1e-4",
19
+ "--warmup_steps=4000",
20
+ "--length_column_name=input_length",
21
+ "--evaluation_strategy=steps",
22
+ "--text_column_name=text",
23
+ "--save_steps=500",
24
+ "--eval_steps=500",
25
+ "--logging_steps=100",
26
+ "--layerdrop=0.041",
27
+ "--attention_dropout=0.094",
28
+ "--activation_dropout=0.055",
29
+ "--hidden_dropout=0.047",
30
+ "--save_total_limit=3",
31
+ "--freeze_feature_encoder",
32
+ "--feat_proj_dropout=0.04",
33
+ "--mask_time_prob=0.082",
34
+ "--mask_time_length=10",
35
+ "--mask_feature_prob=0.25",
36
+ "--mask_feature_length=64",
37
+ "--gradient_checkpointing",
38
+ "--min_duration_in_seconds=0.5",
39
+ "--max_duration_in_seconds=20.0",
40
+ "--use_auth_token",
41
+ "--seed=42",
42
+ "--group_by_length",
43
+ "--do_train",
44
+ "--do_eval",
45
+ "--push_to_hub",
46
+ "--preprocessing_num_workers=32",
47
+ "--ctc_zero_infinity",
48
+ "--do_lower_case",
49
+ "--wandb_project",
50
+ "wav2vec2",
51
+ "--wandb_name",
52
+ "wav2vec2-1b-npsc-nst",
53
+ "--remove_punctuation"
54
+ ],
55
+ "state": "running",
56
+ "program": "run_flax_speech_recognition_ctc.py",
57
+ "codePath": "run_flax_speech_recognition_ctc.py",
58
+ "git": {
59
+ "remote": "https://huggingface.co/NbAiLab/wav2vec2-1b-npsc-nst",
60
+ "commit": "63f0838b605b109a08e90f07fe84d6a94047f139"
61
+ },
62
+ "email": "versae@gmail.com",
63
+ "root": "/data/wav2vec2-1b-npsc-nst",
64
+ "host": "t1v-n-eedfb410-w-0",
65
+ "username": "javierr",
66
+ "executable": "/data/flax/bin/python"
67
+ }
wandb/run-20220729_183213-356uc50u/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
1
+ {"_wandb": {"runtime": 256}}
wandb/run-20220729_183213-356uc50u/logs/debug-internal.log ADDED
@@ -0,0 +1,301 @@
1
+ 2022-07-29 18:32:14,486 INFO MainThread:136862 [internal.py:wandb_internal():87] W&B internal server running at pid: 136862, started at: 2022-07-29 18:32:14.486632
2
+ 2022-07-29 18:32:14,488 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: check_version
3
+ 2022-07-29 18:32:14,489 INFO WriterThread:136862 [datastore.py:open_for_write():77] open: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/run-356uc50u.wandb
4
+ 2022-07-29 18:32:14,489 DEBUG SenderThread:136862 [sender.py:send():234] send: header
5
+ 2022-07-29 18:32:14,490 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: check_version
6
+ 2022-07-29 18:32:14,527 DEBUG SenderThread:136862 [sender.py:send():234] send: run
7
+ 2022-07-29 18:32:14,729 INFO SenderThread:136862 [dir_watcher.py:__init__():169] watching files in: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files
8
+ 2022-07-29 18:32:14,729 INFO SenderThread:136862 [sender.py:_start_run_threads():804] run started: 356uc50u with start time 1659119533
9
+ 2022-07-29 18:32:14,729 DEBUG SenderThread:136862 [sender.py:send():234] send: summary
10
+ 2022-07-29 18:32:14,729 INFO SenderThread:136862 [sender.py:_save_file():939] saving file wandb-summary.json with policy end
11
+ 2022-07-29 18:32:14,730 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: run_start
12
+ 2022-07-29 18:32:15,737 INFO Thread-8 :136862 [dir_watcher.py:_on_file_created():217] file/dir created: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/wandb-summary.json
13
+ 2022-07-29 18:32:17,028 DEBUG HandlerThread:136862 [meta.py:__init__():40] meta init
14
+ 2022-07-29 18:32:17,029 DEBUG HandlerThread:136862 [meta.py:__init__():54] meta init done
15
+ 2022-07-29 18:32:17,029 DEBUG HandlerThread:136862 [meta.py:probe():214] probe
16
+ 2022-07-29 18:32:17,030 DEBUG HandlerThread:136862 [meta.py:_setup_git():204] setup git
17
+ 2022-07-29 18:32:17,062 DEBUG HandlerThread:136862 [meta.py:_setup_git():211] setup git done
18
+ 2022-07-29 18:32:17,062 DEBUG HandlerThread:136862 [meta.py:_save_code():92] save code
19
+ 2022-07-29 18:32:17,074 DEBUG HandlerThread:136862 [meta.py:_save_code():113] save code done
20
+ 2022-07-29 18:32:17,074 DEBUG HandlerThread:136862 [meta.py:_save_patches():130] save patches
21
+ 2022-07-29 18:32:17,166 DEBUG HandlerThread:136862 [meta.py:_save_patches():172] save patches done
22
+ 2022-07-29 18:32:17,166 DEBUG HandlerThread:136862 [meta.py:_save_pip():58] save pip
23
+ 2022-07-29 18:32:17,166 DEBUG HandlerThread:136862 [meta.py:_save_pip():72] save pip done
24
+ 2022-07-29 18:32:17,166 DEBUG HandlerThread:136862 [meta.py:probe():252] probe done
25
+ 2022-07-29 18:32:17,193 DEBUG SenderThread:136862 [sender.py:send():234] send: files
26
+ 2022-07-29 18:32:17,193 INFO SenderThread:136862 [sender.py:_save_file():939] saving file wandb-metadata.json with policy now
27
+ 2022-07-29 18:32:17,193 INFO SenderThread:136862 [sender.py:_save_file():939] saving file code/run_flax_speech_recognition_ctc.py with policy now
28
+ 2022-07-29 18:32:17,199 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
29
+ 2022-07-29 18:32:17,199 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
30
+ 2022-07-29 18:32:17,736 INFO Thread-8 :136862 [dir_watcher.py:_on_file_created():217] file/dir created: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/code/run_flax_speech_recognition_ctc.py
31
+ 2022-07-29 18:32:17,737 INFO Thread-8 :136862 [dir_watcher.py:_on_file_created():217] file/dir created: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/requirements.txt
32
+ 2022-07-29 18:32:17,737 INFO Thread-8 :136862 [dir_watcher.py:_on_file_created():217] file/dir created: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
33
+ 2022-07-29 18:32:17,737 INFO Thread-8 :136862 [dir_watcher.py:_on_file_created():217] file/dir created: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/wandb-metadata.json
34
+ 2022-07-29 18:32:17,737 INFO Thread-8 :136862 [dir_watcher.py:_on_file_created():217] file/dir created: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/code
35
+ 2022-07-29 18:32:17,880 INFO Thread-12 :136862 [upload_job.py:push():137] Uploaded file /tmp/tmp1arbfimxwandb/oqv2t90y-code/run_flax_speech_recognition_ctc.py
36
+ 2022-07-29 18:32:18,151 INFO Thread-11 :136862 [upload_job.py:push():137] Uploaded file /tmp/tmp1arbfimxwandb/1hi0yjav-wandb-metadata.json
37
+ 2022-07-29 18:32:19,737 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
38
+ 2022-07-29 18:32:21,738 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
39
+ 2022-07-29 18:32:23,739 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
40
+ 2022-07-29 18:32:25,740 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
41
+ 2022-07-29 18:32:27,741 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
42
+ 2022-07-29 18:32:29,742 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
43
+ 2022-07-29 18:32:31,743 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
44
+ 2022-07-29 18:32:32,337 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
45
+ 2022-07-29 18:32:32,338 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
46
+ 2022-07-29 18:32:33,744 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
47
+ 2022-07-29 18:32:35,745 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
48
+ 2022-07-29 18:32:37,746 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
49
+ 2022-07-29 18:32:39,747 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
50
+ 2022-07-29 18:32:41,748 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
51
+ 2022-07-29 18:32:43,749 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
52
+ 2022-07-29 18:32:45,108 DEBUG SenderThread:136862 [sender.py:send():234] send: stats
53
+ 2022-07-29 18:32:45,750 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
54
+ 2022-07-29 18:32:47,497 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
55
+ 2022-07-29 18:32:47,497 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
56
+ 2022-07-29 18:32:47,751 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
57
+ 2022-07-29 18:32:49,752 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
58
+ 2022-07-29 18:32:51,753 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
59
+ 2022-07-29 18:32:53,753 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
60
+ 2022-07-29 18:32:55,754 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
61
+ 2022-07-29 18:32:57,755 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
62
+ 2022-07-29 18:32:59,756 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
63
+ 2022-07-29 18:33:01,757 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
64
+ 2022-07-29 18:33:02,633 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
65
+ 2022-07-29 18:33:02,633 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
66
+ 2022-07-29 18:33:03,758 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
67
+ 2022-07-29 18:33:05,759 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
68
+ 2022-07-29 18:33:07,760 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
69
+ 2022-07-29 18:33:09,761 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
70
+ 2022-07-29 18:33:11,761 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
71
+ 2022-07-29 18:33:13,762 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
72
+ 2022-07-29 18:33:15,183 DEBUG SenderThread:136862 [sender.py:send():234] send: stats
73
+ 2022-07-29 18:33:15,763 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
74
+ 2022-07-29 18:33:17,764 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
75
+ 2022-07-29 18:33:17,772 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
76
+ 2022-07-29 18:33:17,772 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
77
+ 2022-07-29 18:33:19,765 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
78
+ 2022-07-29 18:33:21,766 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
79
+ 2022-07-29 18:33:23,767 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
80
+ 2022-07-29 18:33:25,768 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
81
+ 2022-07-29 18:33:27,769 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
82
+ 2022-07-29 18:33:29,770 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
83
+ 2022-07-29 18:33:31,771 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
84
+ 2022-07-29 18:33:32,909 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
85
+ 2022-07-29 18:33:32,909 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
86
+ 2022-07-29 18:33:33,772 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
87
+ 2022-07-29 18:33:35,773 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
88
+ 2022-07-29 18:33:37,774 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
89
+ 2022-07-29 18:33:39,775 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
90
+ 2022-07-29 18:33:41,776 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
91
+ 2022-07-29 18:33:43,777 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
92
+ 2022-07-29 18:33:45,245 DEBUG SenderThread:136862 [sender.py:send():234] send: stats
93
+ 2022-07-29 18:33:45,778 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
94
+ 2022-07-29 18:33:47,779 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
95
+ 2022-07-29 18:33:48,051 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
96
+ 2022-07-29 18:33:48,051 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
97
+ 2022-07-29 18:33:49,780 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
98
+ 2022-07-29 18:33:51,781 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
99
+ 2022-07-29 18:33:53,782 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
100
+ 2022-07-29 18:33:55,783 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
101
+ 2022-07-29 18:33:57,784 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
102
+ 2022-07-29 18:33:59,785 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
103
+ 2022-07-29 18:34:01,786 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
104
+ 2022-07-29 18:34:03,192 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
105
+ 2022-07-29 18:34:03,192 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
106
+ 2022-07-29 18:34:03,786 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
107
+ 2022-07-29 18:34:05,787 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
108
+ 2022-07-29 18:34:07,788 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
109
+ 2022-07-29 18:34:09,789 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
110
+ 2022-07-29 18:34:11,790 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
111
+ 2022-07-29 18:34:13,791 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
112
+ 2022-07-29 18:34:15,308 DEBUG SenderThread:136862 [sender.py:send():234] send: stats
113
+ 2022-07-29 18:34:15,792 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
114
+ 2022-07-29 18:34:17,793 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
115
+ 2022-07-29 18:34:18,334 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
116
+ 2022-07-29 18:34:18,334 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
117
+ 2022-07-29 18:34:19,794 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
118
+ 2022-07-29 18:34:21,795 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
119
+ 2022-07-29 18:34:23,796 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
120
+ 2022-07-29 18:34:25,797 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
121
+ 2022-07-29 18:34:27,798 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
122
+ 2022-07-29 18:34:29,799 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
123
+ 2022-07-29 18:34:31,800 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
124
+ 2022-07-29 18:34:33,472 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
125
+ 2022-07-29 18:34:33,472 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
126
+ 2022-07-29 18:34:33,801 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
127
+ 2022-07-29 18:34:35,802 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
128
+ 2022-07-29 18:34:37,803 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
129
+ 2022-07-29 18:34:39,804 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
130
+ 2022-07-29 18:34:41,805 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
131
+ 2022-07-29 18:34:43,806 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
132
+ 2022-07-29 18:34:45,381 DEBUG SenderThread:136862 [sender.py:send():234] send: stats
133
+ 2022-07-29 18:34:45,807 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
134
+ 2022-07-29 18:34:47,808 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
135
+ 2022-07-29 18:34:48,609 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
136
+ 2022-07-29 18:34:48,610 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
137
+ 2022-07-29 18:34:49,809 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
138
+ 2022-07-29 18:34:51,810 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
139
+ 2022-07-29 18:34:53,811 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
140
+ 2022-07-29 18:34:55,812 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
141
+ 2022-07-29 18:34:57,813 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
142
+ 2022-07-29 18:34:59,814 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
143
+ 2022-07-29 18:35:01,815 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
144
+ 2022-07-29 18:35:03,748 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
145
+ 2022-07-29 18:35:03,748 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
146
+ 2022-07-29 18:35:03,815 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
147
+ 2022-07-29 18:35:05,816 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
148
+ 2022-07-29 18:35:07,817 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
149
+ 2022-07-29 18:35:09,818 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
150
+ 2022-07-29 18:35:11,819 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
151
+ 2022-07-29 18:35:13,820 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
152
+ 2022-07-29 18:35:15,454 DEBUG SenderThread:136862 [sender.py:send():234] send: stats
153
+ 2022-07-29 18:35:15,821 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
154
+ 2022-07-29 18:35:17,822 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
155
+ 2022-07-29 18:35:18,886 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
156
+ 2022-07-29 18:35:18,886 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
157
+ 2022-07-29 18:35:19,823 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
158
+ 2022-07-29 18:35:33,829 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
159
+ 2022-07-29 18:35:34,020 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
160
+ 2022-07-29 18:35:34,021 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
161
+ 2022-07-29 18:35:35,830 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
162
+ 2022-07-29 18:35:37,831 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
163
+ 2022-07-29 18:35:39,831 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
164
+ 2022-07-29 18:35:41,832 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
165
+ 2022-07-29 18:35:43,833 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
166
+ 2022-07-29 18:35:45,525 DEBUG SenderThread:136862 [sender.py:send():234] send: stats
167
+ 2022-07-29 18:35:45,834 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
168
+ 2022-07-29 18:35:47,836 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
169
+ 2022-07-29 18:35:49,158 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
170
+ 2022-07-29 18:35:49,158 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
171
+ 2022-07-29 18:35:49,837 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
172
+ 2022-07-29 18:35:51,838 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
173
+ 2022-07-29 18:35:53,839 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
174
+ 2022-07-29 18:35:55,840 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
175
+ 2022-07-29 18:35:57,841 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
176
+ 2022-07-29 18:35:59,842 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
177
+ 2022-07-29 18:36:01,843 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
178
+ 2022-07-29 18:36:03,844 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
179
+ 2022-07-29 18:36:04,296 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
180
+ 2022-07-29 18:36:04,296 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
181
+ 2022-07-29 18:36:05,845 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
182
+ 2022-07-29 18:36:07,846 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
183
+ 2022-07-29 18:36:09,846 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
184
+ 2022-07-29 18:36:11,847 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
185
+ 2022-07-29 18:36:13,848 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
186
+ 2022-07-29 18:36:15,598 DEBUG SenderThread:136862 [sender.py:send():234] send: stats
187
+ 2022-07-29 18:36:15,849 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
188
+ 2022-07-29 18:36:17,850 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
189
+ 2022-07-29 18:36:19,431 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: stop_status
190
+ 2022-07-29 18:36:19,431 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: stop_status
191
+ 2022-07-29 18:36:19,851 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
192
+ 2022-07-29 18:36:23,853 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
193
+ 2022-07-29 18:36:29,855 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
194
+ 2022-07-29 18:36:30,795 DEBUG SenderThread:136862 [sender.py:send():234] send: telemetry
195
+ 2022-07-29 18:36:30,795 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: poll_exit
196
+ 2022-07-29 18:36:30,795 DEBUG SenderThread:136862 [sender.py:send():234] send: exit
197
+ 2022-07-29 18:36:30,795 INFO SenderThread:136862 [sender.py:send_exit():366] handling exit code: 1
198
+ 2022-07-29 18:36:30,796 INFO SenderThread:136862 [sender.py:send_exit():368] handling runtime: 256
199
+ 2022-07-29 18:36:30,798 INFO SenderThread:136862 [sender.py:_save_file():939] saving file wandb-summary.json with policy end
200
+ 2022-07-29 18:36:30,798 INFO SenderThread:136862 [sender.py:send_exit():374] send defer
201
+ 2022-07-29 18:36:30,798 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: poll_exit
202
+ 2022-07-29 18:36:30,799 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: defer
203
+ 2022-07-29 18:36:30,799 INFO HandlerThread:136862 [handler.py:handle_request_defer():147] handle defer: 0
204
+ 2022-07-29 18:36:30,799 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: defer
205
+ 2022-07-29 18:36:30,799 INFO SenderThread:136862 [sender.py:send_request_defer():383] handle sender defer: 0
206
+ 2022-07-29 18:36:30,799 INFO SenderThread:136862 [sender.py:transition_state():387] send defer: 1
207
+ 2022-07-29 18:36:30,800 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: defer
208
+ 2022-07-29 18:36:30,800 INFO HandlerThread:136862 [handler.py:handle_request_defer():147] handle defer: 1
209
+ 2022-07-29 18:36:30,830 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: defer
210
+ 2022-07-29 18:36:30,830 INFO SenderThread:136862 [sender.py:send_request_defer():383] handle sender defer: 1
211
+ 2022-07-29 18:36:30,830 INFO SenderThread:136862 [sender.py:transition_state():387] send defer: 2
212
+ 2022-07-29 18:36:30,831 DEBUG SenderThread:136862 [sender.py:send():234] send: stats
213
+ 2022-07-29 18:36:30,831 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: defer
214
+ 2022-07-29 18:36:30,831 INFO HandlerThread:136862 [handler.py:handle_request_defer():147] handle defer: 2
215
+ 2022-07-29 18:36:30,831 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: defer
216
+ 2022-07-29 18:36:30,831 INFO SenderThread:136862 [sender.py:send_request_defer():383] handle sender defer: 2
217
+ 2022-07-29 18:36:30,831 INFO SenderThread:136862 [sender.py:transition_state():387] send defer: 3
218
+ 2022-07-29 18:36:30,832 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: defer
219
+ 2022-07-29 18:36:30,832 INFO HandlerThread:136862 [handler.py:handle_request_defer():147] handle defer: 3
220
+ 2022-07-29 18:36:30,832 DEBUG SenderThread:136862 [sender.py:send():234] send: summary
221
+ 2022-07-29 18:36:30,832 INFO SenderThread:136862 [sender.py:_save_file():939] saving file wandb-summary.json with policy end
222
+ 2022-07-29 18:36:30,832 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: defer
223
+ 2022-07-29 18:36:30,832 INFO SenderThread:136862 [sender.py:send_request_defer():383] handle sender defer: 3
224
+ 2022-07-29 18:36:30,832 INFO SenderThread:136862 [sender.py:transition_state():387] send defer: 4
225
+ 2022-07-29 18:36:30,832 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: defer
226
+ 2022-07-29 18:36:30,833 INFO HandlerThread:136862 [handler.py:handle_request_defer():147] handle defer: 4
227
+ 2022-07-29 18:36:30,833 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: defer
228
+ 2022-07-29 18:36:30,833 INFO SenderThread:136862 [sender.py:send_request_defer():383] handle sender defer: 4
229
+ 2022-07-29 18:36:30,856 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/wandb-summary.json
230
+ 2022-07-29 18:36:30,856 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
231
+ 2022-07-29 18:36:30,900 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: poll_exit
232
+ 2022-07-29 18:36:31,177 INFO SenderThread:136862 [sender.py:transition_state():387] send defer: 5
233
+ 2022-07-29 18:36:31,177 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: poll_exit
234
+ 2022-07-29 18:36:31,178 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: defer
235
+ 2022-07-29 18:36:31,178 INFO HandlerThread:136862 [handler.py:handle_request_defer():147] handle defer: 5
236
+ 2022-07-29 18:36:31,178 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: defer
237
+ 2022-07-29 18:36:31,178 INFO SenderThread:136862 [sender.py:send_request_defer():383] handle sender defer: 5
238
+ 2022-07-29 18:36:31,178 INFO SenderThread:136862 [dir_watcher.py:finish():283] shutting down directory watcher
239
+ 2022-07-29 18:36:31,279 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: poll_exit
240
+ 2022-07-29 18:36:31,856 INFO Thread-8 :136862 [dir_watcher.py:_on_file_modified():230] file/dir modified: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/config.yaml
241
+ 2022-07-29 18:36:31,857 INFO SenderThread:136862 [dir_watcher.py:finish():313] scan: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files
242
+ 2022-07-29 18:36:31,857 INFO SenderThread:136862 [dir_watcher.py:finish():327] scan save: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/config.yaml config.yaml
243
+ 2022-07-29 18:36:31,857 INFO SenderThread:136862 [dir_watcher.py:finish():327] scan save: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/requirements.txt requirements.txt
244
+ 2022-07-29 18:36:31,857 INFO SenderThread:136862 [dir_watcher.py:finish():327] scan save: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log output.log
245
+ 2022-07-29 18:36:31,858 INFO SenderThread:136862 [dir_watcher.py:finish():327] scan save: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/wandb-summary.json wandb-summary.json
246
+ 2022-07-29 18:36:31,858 INFO SenderThread:136862 [dir_watcher.py:finish():327] scan save: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/wandb-metadata.json wandb-metadata.json
247
+ 2022-07-29 18:36:31,863 INFO SenderThread:136862 [dir_watcher.py:finish():327] scan save: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/code/run_flax_speech_recognition_ctc.py code/run_flax_speech_recognition_ctc.py
248
+ 2022-07-29 18:36:31,863 INFO SenderThread:136862 [sender.py:transition_state():387] send defer: 6
249
+ 2022-07-29 18:36:31,864 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: poll_exit
250
+ 2022-07-29 18:36:31,870 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: defer
251
+ 2022-07-29 18:36:31,870 INFO HandlerThread:136862 [handler.py:handle_request_defer():147] handle defer: 6
252
+ 2022-07-29 18:36:31,870 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: defer
253
+ 2022-07-29 18:36:31,870 INFO SenderThread:136862 [sender.py:send_request_defer():383] handle sender defer: 6
254
+ 2022-07-29 18:36:31,870 INFO SenderThread:136862 [file_pusher.py:finish():177] shutting down file pusher
255
+ 2022-07-29 18:36:31,965 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: poll_exit
256
+ 2022-07-29 18:36:31,965 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: poll_exit
257
+ 2022-07-29 18:36:32,067 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: poll_exit
258
+ 2022-07-29 18:36:32,067 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: poll_exit
259
+ 2022-07-29 18:36:32,169 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: poll_exit
260
+ 2022-07-29 18:36:32,169 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: poll_exit
261
+ 2022-07-29 18:36:32,270 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: poll_exit
262
+ 2022-07-29 18:36:32,271 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: poll_exit
263
+ 2022-07-29 18:36:32,332 INFO Thread-13 :136862 [upload_job.py:push():137] Uploaded file /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/config.yaml
264
+ 2022-07-29 18:36:32,338 INFO Thread-16 :136862 [upload_job.py:push():137] Uploaded file /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/wandb-summary.json
265
+ 2022-07-29 18:36:32,340 INFO Thread-14 :136862 [upload_job.py:push():137] Uploaded file /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/requirements.txt
266
+ 2022-07-29 18:36:32,348 INFO Thread-15 :136862 [upload_job.py:push():137] Uploaded file /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/files/output.log
267
+ 2022-07-29 18:36:32,372 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: poll_exit
268
+ 2022-07-29 18:36:32,372 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: poll_exit
269
+ 2022-07-29 18:36:32,473 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: poll_exit
270
+ 2022-07-29 18:36:32,474 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: poll_exit
271
+ 2022-07-29 18:36:32,548 INFO Thread-7 :136862 [sender.py:transition_state():387] send defer: 7
272
+ 2022-07-29 18:36:32,549 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: defer
273
+ 2022-07-29 18:36:32,549 INFO HandlerThread:136862 [handler.py:handle_request_defer():147] handle defer: 7
274
+ 2022-07-29 18:36:32,549 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: defer
275
+ 2022-07-29 18:36:32,549 INFO SenderThread:136862 [sender.py:send_request_defer():383] handle sender defer: 7
276
+ 2022-07-29 18:36:32,575 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: poll_exit
277
+ 2022-07-29 18:36:33,221 INFO SenderThread:136862 [sender.py:transition_state():387] send defer: 8
278
+ 2022-07-29 18:36:33,221 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: poll_exit
279
+ 2022-07-29 18:36:33,222 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: defer
280
+ 2022-07-29 18:36:33,222 INFO HandlerThread:136862 [handler.py:handle_request_defer():147] handle defer: 8
281
+ 2022-07-29 18:36:33,222 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: defer
282
+ 2022-07-29 18:36:33,222 INFO SenderThread:136862 [sender.py:send_request_defer():383] handle sender defer: 8
283
+ 2022-07-29 18:36:33,222 INFO SenderThread:136862 [sender.py:transition_state():387] send defer: 9
284
+ 2022-07-29 18:36:33,223 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: defer
285
+ 2022-07-29 18:36:33,223 INFO HandlerThread:136862 [handler.py:handle_request_defer():147] handle defer: 9
286
+ 2022-07-29 18:36:33,223 DEBUG SenderThread:136862 [sender.py:send():234] send: final
287
+ 2022-07-29 18:36:33,223 DEBUG SenderThread:136862 [sender.py:send():234] send: footer
288
+ 2022-07-29 18:36:33,224 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: defer
289
+ 2022-07-29 18:36:33,224 INFO SenderThread:136862 [sender.py:send_request_defer():383] handle sender defer: 9
290
+ 2022-07-29 18:36:33,323 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: poll_exit
291
+ 2022-07-29 18:36:33,323 DEBUG SenderThread:136862 [sender.py:send_request():248] send_request: poll_exit
292
+ 2022-07-29 18:36:33,323 INFO SenderThread:136862 [file_pusher.py:join():182] waiting for file pusher
293
+ 2022-07-29 18:36:33,628 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: get_summary
294
+ 2022-07-29 18:36:33,629 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: sampled_history
295
+ 2022-07-29 18:36:33,630 DEBUG HandlerThread:136862 [handler.py:handle_request():130] handle_request: shutdown
296
+ 2022-07-29 18:36:33,630 INFO HandlerThread:136862 [handler.py:finish():731] shutting down handler
297
+ 2022-07-29 18:36:34,224 INFO WriterThread:136862 [datastore.py:close():281] close: /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/run-356uc50u.wandb
298
+ 2022-07-29 18:36:34,627 INFO SenderThread:136862 [sender.py:finish():1070] shutting down sender
299
+ 2022-07-29 18:36:34,627 INFO SenderThread:136862 [file_pusher.py:finish():177] shutting down file pusher
300
+ 2022-07-29 18:36:34,627 INFO SenderThread:136862 [file_pusher.py:join():182] waiting for file pusher
301
+ 2022-07-29 18:36:34,630 INFO MainThread:136862 [internal.py:handle_exit():77] Internal process exited
wandb/run-20220729_183213-356uc50u/logs/debug.log ADDED
@@ -0,0 +1,130 @@
1
+ 2022-07-29 18:32:13,607 INFO MainThread:135604 [wandb_setup.py:_flush():71] setting env: {'project': 'wav2vec2', 'entity': 'NbAiLab'}
2
+ 2022-07-29 18:32:13,608 INFO MainThread:135604 [wandb_setup.py:_flush():71] setting login settings: {}
3
+ 2022-07-29 18:32:13,608 INFO MainThread:135604 [wandb_init.py:_log_setup():371] Logging user logs to /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/logs/debug.log
4
+ 2022-07-29 18:32:13,608 INFO MainThread:135604 [wandb_init.py:_log_setup():372] Logging internal logs to /data/wav2vec2-1b-npsc-nst/wandb/run-20220729_183213-356uc50u/logs/debug-internal.log
5
+ 2022-07-29 18:32:13,608 INFO MainThread:135604 [wandb_init.py:init():404] calling init triggers
6
+ 2022-07-29 18:32:13,608 INFO MainThread:135604 [wandb_init.py:init():409] wandb.init called with sweep_config: {}
7
+ config: {}
8
+ 2022-07-29 18:32:13,608 INFO MainThread:135604 [wandb_init.py:init():460] starting backend
9
+ 2022-07-29 18:32:13,608 INFO MainThread:135604 [backend.py:_multiprocessing_setup():99] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
10
+ 2022-07-29 18:32:13,667 INFO MainThread:135604 [backend.py:ensure_launched():216] starting backend process...
11
+ 2022-07-29 18:32:13,694 INFO MainThread:135604 [backend.py:ensure_launched():221] started backend process with pid: 136862
12
+ 2022-07-29 18:32:13,698 INFO MainThread:135604 [wandb_init.py:init():469] backend started and connected
13
+ 2022-07-29 18:32:13,713 INFO MainThread:135604 [wandb_init.py:init():533] updated telemetry
14
+ 2022-07-29 18:32:13,778 INFO MainThread:135604 [wandb_init.py:init():563] communicating current version
15
+ 2022-07-29 18:32:14,526 INFO MainThread:135604 [wandb_init.py:init():568] got version response upgrade_message: "wandb version 0.12.21 is available! To upgrade, please run:\n $ pip install wandb --upgrade"
16
+
17
+ 2022-07-29 18:32:14,526 INFO MainThread:135604 [wandb_init.py:init():578] communicating run to backend with 30 second timeout
18
+ 2022-07-29 18:32:14,730 INFO MainThread:135604 [wandb_init.py:init():606] starting run threads in backend
19
+ 2022-07-29 18:32:17,197 INFO MainThread:135604 [wandb_run.py:_console_start():1810] atexit reg
20
+ 2022-07-29 18:32:17,197 INFO MainThread:135604 [wandb_run.py:_redirect():1684] redirect: SettingsConsole.REDIRECT
21
+ 2022-07-29 18:32:17,198 INFO MainThread:135604 [wandb_run.py:_redirect():1689] Redirecting console.
22
+ 2022-07-29 18:32:17,204 INFO MainThread:135604 [wandb_run.py:_redirect():1745] Redirects installed.
23
+ 2022-07-29 18:32:17,204 INFO MainThread:135604 [wandb_init.py:init():633] run started, returning control to user process
24
+ 2022-07-29 18:36:28,486 INFO MainThread:135604 [wandb_run.py:_atexit_cleanup():1780] got exitcode: 1
25
+ 2022-07-29 18:36:28,502 INFO MainThread:135604 [wandb_run.py:_restore():1752] restore
26
+ 2022-07-29 18:36:30,799 INFO MainThread:135604 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
27
+ wandb_count: 1
28
+ other_count: 1
29
+ }
30
+ pusher_stats {
31
+ uploaded_bytes: 73662
32
+ total_bytes: 73662
33
+ }
34
+
35
+ 2022-07-29 18:36:31,178 INFO MainThread:135604 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
36
+ wandb_count: 1
37
+ other_count: 1
38
+ }
39
+ pusher_stats {
40
+ uploaded_bytes: 73662
41
+ total_bytes: 73662
42
+ }
43
+
44
+ 2022-07-29 18:36:31,864 INFO MainThread:135604 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
45
+ wandb_count: 5
46
+ other_count: 1
47
+ }
48
+ pusher_stats {
49
+ uploaded_bytes: 73662
50
+ total_bytes: 98442
51
+ }
52
+
53
+ 2022-07-29 18:36:31,966 INFO MainThread:135604 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
54
+ wandb_count: 5
55
+ other_count: 1
56
+ }
57
+ pusher_stats {
58
+ uploaded_bytes: 73662
59
+ total_bytes: 98442
60
+ }
61
+
62
+ 2022-07-29 18:36:32,068 INFO MainThread:135604 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
63
+ wandb_count: 5
64
+ other_count: 1
65
+ }
66
+ pusher_stats {
67
+ uploaded_bytes: 98442
68
+ total_bytes: 98442
69
+ }
70
+
71
+ 2022-07-29 18:36:32,170 INFO MainThread:135604 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
72
+ wandb_count: 5
73
+ other_count: 1
74
+ }
75
+ pusher_stats {
76
+ uploaded_bytes: 98442
77
+ total_bytes: 98442
78
+ }
79
+
80
+ 2022-07-29 18:36:32,271 INFO MainThread:135604 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
81
+ wandb_count: 5
82
+ other_count: 1
83
+ }
84
+ pusher_stats {
85
+ uploaded_bytes: 98442
86
+ total_bytes: 98442
87
+ }
88
+
89
+ 2022-07-29 18:36:32,373 INFO MainThread:135604 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
90
+ wandb_count: 5
91
+ other_count: 1
92
+ }
93
+ pusher_stats {
94
+ uploaded_bytes: 98442
95
+ total_bytes: 98442
96
+ }
97
+
98
+ 2022-07-29 18:36:32,474 INFO MainThread:135604 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
99
+ wandb_count: 5
100
+ other_count: 1
101
+ }
102
+ pusher_stats {
103
+ uploaded_bytes: 98442
104
+ total_bytes: 98442
105
+ }
106
+
107
+ 2022-07-29 18:36:33,222 INFO MainThread:135604 [wandb_run.py:_wait_for_finish():1912] got exit ret: file_counts {
108
+ wandb_count: 5
109
+ other_count: 1
110
+ }
111
+ pusher_stats {
112
+ uploaded_bytes: 98442
113
+ total_bytes: 98442
114
+ }
115
+
116
+ 2022-07-29 18:36:33,627 INFO MainThread:135604 [wandb_run.py:_wait_for_finish():1912] got exit ret: done: true
117
+ exit_result {
118
+ }
119
+ file_counts {
120
+ wandb_count: 5
121
+ other_count: 1
122
+ }
123
+ pusher_stats {
124
+ uploaded_bytes: 98442
125
+ total_bytes: 98442
126
+ }
127
+ local_info {
128
+ }
129
+
130
+ 2022-07-29 18:36:35,126 INFO MainThread:135604 [wandb_run.py:_append_files():2180] logging synced files
wandb/run-20220729_183213-356uc50u/run-356uc50u.wandb ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abb83d82b3b2c65b07ff90b0a90d06625b3524560ee7653e99485761b1a56795
3
+ size 73924
wandb/run-20220729_184558-17ksemgv/files/code/run_flax_speech_recognition_ctc.py ADDED
@@ -0,0 +1,1596 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the Flax library models for connectionist temporal classification (CTC) speech recognition.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ import logging
22
+ import math
23
+ import os
24
+ import re
25
+ import sys
26
+ import time
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, Dict, List, Optional, Union
30
+
31
+ import datasets
32
+ import numpy as np
33
+ from datasets import DatasetDict, load_dataset, load_metric
34
+ from tqdm import tqdm
35
+
36
+ import flax
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ import wandb as wandb
42
+ from flax import core, jax_utils, struct, traverse_util
43
+ from flax.jax_utils import unreplicate, pad_shard_unpad
44
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
45
+ from huggingface_hub import Repository
46
+ from models import Wav2Vec2Config, FlaxWav2Vec2ForCTC
47
+ from optax._src import linear_algebra
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ HfArgumentParser,
53
+ TrainingArguments,
54
+ is_tensorboard_available,
55
+ set_seed,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name
58
+ from transformers.utils import check_min_version
59
+ from transformers.utils.versions import require_version
60
+
61
+
62
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
63
+ check_min_version("4.17.0.dev0")
64
+
65
+ require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+
70
+ @flax.struct.dataclass
71
+ class ModelArguments:
72
+ """
73
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
74
+ """
75
+
76
+ model_name_or_path: str = field(
77
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
78
+ )
79
+ config_name: Optional[str] = field(
80
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
81
+ )
82
+ tokenizer_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
84
+ )
85
+ feature_extractor_name: Optional[str] = field(
86
+ default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
87
+ )
88
+ cache_dir: Optional[str] = field(
89
+ default=None,
90
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
91
+ )
92
+ use_fast_tokenizer: bool = field(
93
+ default=True,
94
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
95
+ )
96
+ model_revision: str = field(
97
+ default="main",
98
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
99
+ )
100
+ use_auth_token: bool = field(
101
+ default=False,
102
+ metadata={
103
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
104
+ "with private models)."
105
+ },
106
+ )
107
+ freeze_feature_encoder: bool = field(
108
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
109
+ )
110
+ attention_dropout: float = field(
111
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
112
+ )
113
+ activation_dropout: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "The hidden activation dropout probability in the embeddings, encoder, and pooler."
117
+ },
118
+ )
119
+ hidden_dropout: float = field(
120
+ default=0.1,
121
+ metadata={
122
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
123
+ },
124
+ )
125
+ feat_proj_dropout: float = field(
126
+ default=0.0,
127
+ metadata={
128
+ "help": "The feat proj dropout probability for feature encoder representations."
129
+ },
130
+ )
131
+ final_dropout: float = field(
132
+ default=0.0,
133
+ metadata={"help": "The dropout probability for the final projection layer."},
134
+ )
135
+ mask_time_prob: float = field(
136
+ default=0.1,
137
+ metadata={
138
+ "help": "The SpecAugment masking probability along the time axis for feature encoder representations."
139
+ },
140
+ )
141
+ mask_time_length: int = field(
142
+ default=10,
143
+ metadata={"help": "Length of vector span to mask along the time axis."},
144
+ )
145
+ mask_feature_prob: float = field(
146
+ default=0.0,
147
+ metadata={
148
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
149
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the feature axis."
150
+ },
151
+ )
152
+ mask_feature_length: int = field(
153
+ default=10,
154
+ metadata={"help": "Length of vector span to mask along the feature axis."},
155
+ )
156
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
157
+ ctc_loss_reduction: Optional[str] = field(
158
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
159
+ )
160
+ ctc_zero_infinity: Optional[bool] = field(
161
+ default=False, metadata={"help": "If True, will try to avoid the CTC loss going to infinity."}
162
+ )
163
+
164
+
165
+ @flax.struct.dataclass
166
+ class DataTrainingArguments:
167
+ """
168
+ Arguments pertaining to what data we are going to input our model for training and eval.
169
+ """
170
+
171
+ dataset_name: str = field(
172
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
173
+ )
174
+ dataset_config_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
176
+ )
177
+ text_column: Optional[str] = field(
178
+ default=None,
179
+ metadata={"help": "The name of the column in the datasets containing the full texts (i.e. the transcriptions)."},
180
+ )
181
+ dataset_cache_dir: Optional[str] = field(
182
+ default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
183
+ )
184
+ overwrite_cache: bool = field(
185
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
186
+ )
187
+ preprocessing_num_workers: Optional[int] = field(
188
+ default=None,
189
+ metadata={"help": "The number of processes to use for the preprocessing."},
190
+ )
191
+ max_train_samples: Optional[int] = field(
192
+ default=None,
193
+ metadata={
194
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
195
+ "value if set."
196
+ },
197
+ )
198
+ max_eval_samples: Optional[int] = field(
199
+ default=None,
200
+ metadata={
201
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
202
+ "value if set."
203
+ },
204
+ )
205
+ max_test_samples: Optional[int] = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
209
+ "value if set."
210
+ },
211
+ )
212
+ audio_column_name: str = field(
213
+ default="audio",
214
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
215
+ )
216
+ text_column_name: str = field(
217
+ default="text",
218
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
219
+ )
220
+ max_duration_in_seconds: float = field(
221
+ default=20.0,
222
+ metadata={
223
+ "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to `max_duration_in_seconds`"
224
+ },
225
+ )
226
+ min_duration_in_seconds: float = field(
227
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
228
+ )
229
+ max_label_length: Optional[int] = field(
230
+ default=512,
231
+ metadata={
232
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
233
+ "than this will be filtered."
234
+ },
235
+ )
236
+ min_label_length: Optional[int] = field(
237
+ default=2,
238
+ metadata={
239
+ "help": "The minimum total sequence length for target text after tokenization. Sequences shorter "
240
+ "than this will be filtered."
241
+ },
242
+ )
243
+ pad_input_to_multiple_of: Optional[int] = field(
244
+ default=32000,
245
+ metadata={
246
+ "help": "If set will pad the input sequence to a multiple of the provided value. "
247
+ "This is important to avoid triggering recompilations on TPU."
248
+ },
249
+ )
250
+ pad_target_to_multiple_of: Optional[int] = field(
251
+ default=None,
252
+ metadata={
253
+ "help": "If set will pad the target sequence to a multiple of the provided value. "
254
+ "This is important to avoid triggering recompilations on TPU."
255
+ },
256
+ )
257
+ preprocessing_only: bool = field(
258
+ default=False,
259
+ metadata={
260
+ "help": "Whether to only do data preprocessing and skip training. "
261
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
262
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
263
+ "so that the cached datasets can consequently be loaded in distributed training"
264
+ },
265
+ )
266
+ train_split_name: str = field(
267
+ default="train",
268
+ metadata={
269
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
270
+ },
271
+ )
272
+ eval_split_name: str = field(
273
+ default="validation",
274
+ metadata={
275
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
276
+ },
277
+ )
278
+ do_lower_case: bool = field(
279
+ default=True,
280
+ metadata={"help": "Whether the target text should be lower cased."},
281
+ )
282
+ wandb_project: str = field(
283
+ default="flax-speech-recognition-ctc",
284
+ metadata={"help": "The name of the wandb project."},
285
+ )
286
+ wandb_name: str = field(
287
+ default=None,
288
+ metadata={"help": "The name of the wandb run."},
289
+ )
290
+ wandb_job_type: str = field(
291
+ default="CTC",
292
+ metadata={"help": "The name of the wandb job type."},
293
+ )
294
+ test_split_name: str = field(
295
+ default="test",
296
+ metadata={"help": "The name of the test data set split to use (via the datasets library). Defaults to 'test'"},
297
+ )
298
+ remove_punctuation: bool = field(
299
+ default=False, metadata={"help": "Whether or not to remove punctuation during training."}
300
+ )
301
+
302
+
303
+ # @flax.struct.dataclass
304
+ @dataclass
305
+ class FlaxTrainingArguments(TrainingArguments):
306
+ precision: str = field(
307
+ default="full",
308
+ metadata={
309
+ "help": "Whether to enable mixed-precision training. If true, the optimizer is stored in half-precision (bfloat16) and computations are executed in half-precision. "
310
+ "**Note that this only specifies the dtype of the computation and optimizer state. It does not influence the dtype of model parameters.**"
311
+ },
312
+ )
313
+ matmul_precision: str = field(
314
+ default="default",
315
+ metadata={
316
+ "help": "Default floating-point precision of internal computations used in TPU matrix multiplications and convolutions. "
317
+ "This configuration option controls the default precision for JAX operations that take an optional precision argument (e.g. `lax.conv_general_dilated` and `lax.dot`). "
318
+ "This configuration option does not change the behaviours of such calls with explicit precision arguments; "
319
+ "it only changes the behaviors of calls with no such argument provided. "
320
+ "One of `['highest', 'float32', 'high', 'bfloat16_3x', 'default', 'bfloat16', 'fastest', None]`."
321
+ },
322
+ )
323
+ multisteps: bool = field(
324
+ default=False,
325
+ metadata={
326
+ "help": "Whether to use Optax MultiSteps for gradient accumulation. If `False` (default) and `gradient_accumulation_steps > 1`, "
327
+ "a custom gradient accumulation implementation will be employed."
328
+ },
329
+ )
330
+
331
+
332
+ def to_fp32(t):
333
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
334
+
335
+
336
+ def to_bf16(t):
337
+ return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
338
+
339
+
340
+ class MixedPrecisionTrainState(struct.PyTreeNode):
341
+ """Train state for use with a single Optax optimizer.
342
+ Adapted from flax train_state https://github.com/google/flax/blob/main/flax/training/train_state.py
343
+
344
+ Synopsis::
345
+
346
+ state = TrainState.create(
347
+ apply_fn=model.apply,
348
+ params=variables['params'],
349
+ tx=tx)
350
+ grad_fn = jax.grad(make_loss_fn(state.apply_fn))
351
+ for batch in data:
352
+ grads = grad_fn(state.params, batch)
353
+ state = state.apply_gradients(grads=grads)
354
+
355
+ Args:
356
+ step: Counter starts at 0 and is incremented by every call to
357
+ `.apply_gradients()`.
358
+ apply_fn: Usually set to `model.apply()`. Kept in this dataclass for
359
+ convenience to have a shorter params list for the `train_step()` function
360
+ in your training loop.
361
+ params: The parameters to be updated by `tx` and used by `apply_fn`.
362
+ tx: An Optax gradient transformation.
363
+ opt_state: The state for `tx`.
364
+ dropout_rng: PRNG key for stochastic operations.
365
+ bf16: Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.
366
+ """
367
+
368
+ step: int
369
+ apply_fn: Callable = struct.field(pytree_node=False)
370
+ get_attention_mask_fn: Callable = struct.field(pytree_node=False)
371
+ params: core.FrozenDict[str, Any]
372
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
373
+ opt_state: optax.OptState
374
+ dropout_rng: jnp.ndarray
375
+ max_grad_norm: Optional[float] = 1.0
376
+
377
+ def apply_gradients(self, *, grads, to_dtype, **kwargs):
378
+ """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.
379
+
380
+ Note that internally this function calls `.tx.update()` followed by a call
381
+ to `optax.apply_updates()` to update `params` and `opt_state`.
382
+
383
+ Args:
384
+ grads: Gradients that have the same pytree structure as `.params`.
385
+ **kwargs: Additional dataclass attributes that should be `.replace()`-ed.
386
+
387
+ Returns:
388
+ An updated instance of `self` with `step` incremented by one, `params`
389
+ and `opt_state` updated by applying `grads`, and additional attributes
390
+ replaced as specified by `kwargs`.
391
+ """
392
+
393
+ # clip gradients by global l2 norm
394
+ casted_max_grad_norm = to_dtype(self.max_grad_norm)
395
+ g_norm = linear_algebra.global_norm(grads)
396
+ g_norm = jnp.maximum(casted_max_grad_norm, g_norm)
397
+ grads = jax.tree_map(lambda t: (t / g_norm) * casted_max_grad_norm, grads)
398
+
399
+ # perform update step in fp32 and subsequently downcast optimizer states if mixed precision training
400
+ # grads and opt_state in bf16 (need to upcast), params in fp32 (leave as is)
401
+ updates, new_opt_state = self.tx.update(to_fp32(grads), to_fp32(self.opt_state), self.params)
402
+
403
+ new_params = optax.apply_updates(self.params, updates)
404
+ return self.replace(
405
+ step=self.step + 1,
406
+ params=new_params,
407
+ opt_state=to_dtype(new_opt_state),
408
+ **kwargs,
409
+ )
410
+
411
+ @classmethod
412
+ def create(cls, *, apply_fn, params, tx, to_dtype, **kwargs):
413
+ """Creates a new instance with `step=0` and initialized `opt_state`."""
414
+ # downcast optimizer state to bf16 if mixed-precision training
415
+ opt_state = tx.init(to_dtype(params)) if tx is not None else None
416
+ return cls(
417
+ step=0,
418
+ apply_fn=apply_fn,
419
+ params=params,
420
+ tx=tx,
421
+ opt_state=opt_state,
422
+ **kwargs,
423
+ )
424
+
425
+ def replicate(self):
426
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
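A minimal sketch of how this train state might be constructed and replicated for pmapped training (illustrative only; `model`, `params` and `rng` are assumed placeholders, and the attention-mask helper name is an assumption rather than something defined in this diff):

import optax

# Assumed placeholders: `model` (a FlaxWav2Vec2ForCTC instance), its `params`, and a JAX PRNG key `rng`.
tx = optax.adamw(learning_rate=1e-4)
state = MixedPrecisionTrainState.create(
    apply_fn=model.__call__,
    get_attention_mask_fn=model._get_feature_vector_attention_mask,  # assumed helper on the Flax model
    params=params,
    tx=tx,
    to_dtype=to_bf16,   # keep the optimizer state in bf16 when running mixed precision
    dropout_rng=rng,
    max_grad_norm=1.0,
)
state = state.replicate()  # shard the state across local devices before the pmapped train step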
427
+
428
+
429
+ @flax.struct.dataclass
430
+ class FlaxDataCollatorSpeechSeq2SeqWithPadding:
431
+ """
432
+ Data collator that will dynamically pad the inputs received.
433
+ Args:
434
+ processor ([`Wav2Vec2Processor`])
435
+ The processor used for processing the data.
436
+ decoder_start_token_id (:obj: `int`)
437
+ The begin-of-sentence of the decoder.
438
+ input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
439
+ Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
440
+ among:
441
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
442
+ sequence is provided).
443
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
444
+ maximum acceptable input length for the model if that argument is not provided.
445
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
446
+ different lengths).
447
+ target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
448
+ Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
449
+ See above for details.
450
+ max_input_length (:obj:`float`, `optional`):
451
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
452
+ pad_input_to_multiple_of (:obj:`int`, `optional`):
453
+ If set will pad the input sequence to a multiple of the provided value.
454
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
455
+ 7.5 (Volta).
456
+ pad_target_to_multiple_of (:obj:`int`, `optional`):
457
+ If set will pad the target sequence to a multiple of the provided value.
458
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
459
+ 7.5 (Volta).
460
+ """
461
+
462
+ processor: Any
463
+ input_padding: Union[bool, str] = "longest"
464
+ label_padding: Union[bool, str] = "max_length"
465
+ pad_input_to_multiple_of: Optional[int] = None
466
+ pad_to_multiple_of_label: Optional[int] = None
467
+ max_input_length: Optional[float] = None
468
+ max_label_length: Optional[float] = None
469
+
470
+ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
471
+ # split inputs and labels since they have to be of different lengths and need
472
+ # different padding methods
473
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
474
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
475
+
476
+ # reformat list to dict and set to pytorch format
477
+ batch = self.processor.feature_extractor.pad(
478
+ input_features,
479
+ max_length=self.max_input_length,
480
+ padding=self.input_padding,
481
+ pad_to_multiple_of=self.pad_input_to_multiple_of,
482
+ return_tensors="np",
483
+ )
484
+
485
+ labels_batch = self.processor.tokenizer.pad(
486
+ label_features,
487
+ max_length=self.max_label_length,
488
+ padding=self.label_padding,
489
+ pad_to_multiple_of=self.pad_to_multiple_of_label,
490
+ return_tensors="np",
491
+ )
492
+
493
+ labels = labels_batch["input_ids"]
494
+ labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
495
+ labels = labels.filled(fill_value=-100)
496
+
497
+ batch["labels"] = labels
498
+
499
+ return batch
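A rough usage sketch of the collator (names such as `processor` and `examples` are assumed placeholders for a loaded Wav2Vec2Processor and a list of preprocessed dataset rows):

data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    input_padding="longest",
    pad_input_to_multiple_of=32000,  # helps avoid TPU recompilations for varying input lengths
    label_padding="max_length",
    max_label_length=512,
)
features = [{"input_values": ex["input_values"], "labels": ex["labels"]} for ex in examples]
batch = data_collator(features)
# batch["input_values"] and batch["labels"] are padded numpy arrays; label padding is filled with -100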
500
+
501
+
502
+ def get_grouped_indices(
503
+ dataset, batch_size: int, rng: Optional[List[int]] = None, mega_batch_mult: Optional[int] = None
504
+ ) -> np.array:
505
+ """
506
+ Adapted from the `get_length_grouped_indices` function in the PyTorch Trainer utils file (https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L486)
507
+ Function that returns a list of indices in which each slice of `batch_size` consecutive indices correspond to elements of similar
508
+ lengths. To do this, the indices are:
509
+
510
+ - randomly permuted (if a JAX rng is specified)
511
+ - grouped in mega-batches of size `mega_batch_mult * batch_size`
512
+ - sorted by length in each mega-batch
513
+
514
+ The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
515
+ maximum length placed first, so that an OOM happens sooner rather than later.
516
+ """
517
+ lengths = dataset["input_length"]
518
+
519
+ # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
520
+ if mega_batch_mult is None:
521
+ mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
522
+ # Just in case, for tiny datasets
523
+ if mega_batch_mult == 0:
524
+ mega_batch_mult = 1
525
+
526
+ # We need to use JAX for the random permutation as the PRNG key will be set based on the seed outside of the sampler.
527
+ num_samples = len(lengths)
528
+ indices = jax.random.permutation(rng, np.arange(num_samples)) if rng is not None else np.arange(num_samples)
529
+
530
+ megabatch_size = mega_batch_mult * batch_size
531
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
532
+ megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
533
+
534
+ # The rest is to get the biggest batch first.
535
+ # Since each megabatch is sorted by descending length, the longest element is the first
536
+ megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
537
+ max_idx = np.argmax(megabatch_maximums).item()
538
+ # Switch to put the longest batch in first position
539
+ # (note that this is different to the PT grouped sampler in which we only put the longest element in the first position, and not its batch)
540
+ megabatches[0], megabatches[max_idx] = megabatches[max_idx], megabatches[0]
541
+
542
+ megabatches = np.array([i for megabatch in megabatches for i in megabatch])
543
+
544
+ return megabatches
545
+
546
+
547
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
548
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
549
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
550
+ num_samples = len(samples_idx)
551
+ if drop_last:
552
+ samples_to_remove = num_samples % batch_size
553
+ if samples_to_remove != 0:
554
+ samples_idx = samples_idx[:-samples_to_remove]
555
+ sections_split = num_samples // batch_size
556
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
557
+ else:
558
+ sections_split = math.ceil(num_samples / batch_size)
559
+ samples_idx = np.array_split(samples_idx, sections_split)
560
+ return samples_idx
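A brief sketch of how the two helpers above combine into length-grouped batches (`vectorized_datasets` and `rng` are assumed placeholders for the processed dataset with an "input_length" column and a JAX PRNG key):

batch_size = 8
train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size, rng)
train_batch_idx = generate_batch_splits(train_samples_idx, batch_size, drop_last=True)
for batch_idx in train_batch_idx:
    # each slice contains examples of similar input_length, which keeps padding to a minimum
    samples = vectorized_datasets["train"][batch_idx]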
561
+
562
+
563
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
564
+ summary_writer.scalar("train_time", train_time, step)
565
+
566
+ train_metrics = get_metrics(train_metrics)
567
+ for key, vals in train_metrics.items():
568
+ tag = f"train_{key}"
569
+ for i, val in enumerate(vals):
570
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
571
+
572
+
573
+ def write_eval_metric(summary_writer, eval_metrics, step, pred_str=None):
574
+ for metric_name, value in eval_metrics.items():
575
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
576
+
577
+ if pred_str is not None:
578
+ # write output actual predictions for debugging
579
+ summary_writer.text("eval_predictions", "\n".join(pred_str), step)
580
+
581
+
582
+ def write_wandb_log(metrics, step, prefix=None):
583
+ if jax.process_index() == 0:
584
+ log_metrics = {}
585
+ for k, v in metrics.items():
586
+ if "layer" in k:
587
+ log_metrics[f"{k}/"] = v
588
+ elif prefix is not None:
589
+ log_metrics[f"{prefix}/{k}"] = v
590
+ else:
591
+ log_metrics[k] = v
592
+ wandb.log(log_metrics, step)
593
+
594
+
595
+ def write_wandb_pred(pred_str, label_str, step, num_log=50, prefix="eval"):
596
+ if jax.process_index() == 0:
597
+ # convert str data to a wandb compatible format
598
+ str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
599
+ # we'll log the first 50 predictions for each epoch
600
+ wandb.log(
601
+ {
602
+ f"{prefix}/step_{int(step / 1000)}k": wandb.Table(
603
+ columns=["label_str", "pred_str"], data=str_data[:num_log]
604
+ )
605
+ },
606
+ step,
607
+ )
608
+
609
+
610
+ def create_learning_rate_fn(
611
+ num_train_steps: int, num_warmup_steps: int, learning_rate: float
612
+ ) -> Callable[[int], jnp.array]:
613
+ """Returns a linear warmup, linear_decay learning rate function."""
614
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
615
+ decay_fn = optax.linear_schedule(
616
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
617
+ )
618
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
619
+ return schedule_fn
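For example, a schedule with 10k warmup steps out of 100k total steps and a peak learning rate of 1e-4 behaves roughly as follows (illustrative values only):

lr_schedule_fn = create_learning_rate_fn(num_train_steps=100_000, num_warmup_steps=10_000, learning_rate=1e-4)
lr_schedule_fn(0)        # 0.0 at the first step
lr_schedule_fn(10_000)   # ~1e-4 at the end of warmup
lr_schedule_fn(100_000)  # decays linearly back to ~0.0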
620
+
621
+
622
+ def ctc_loss(
623
+ logits,
624
+ logits_attention_mask,
625
+ labels,
626
+ blank_id,
627
+ loss_reduction="mean",
628
+ output_emission_dict=False,
629
+ log_epsilon=-100000.0,
630
+ ):
631
+ """Computes CTC loss.
632
+ This function performs forward computation over an FSA with `N * 2` states
633
+ where `N` is the max number of labels. The states are split into two groups:
634
+ Phi states and emission states. a phi-state accepts repetition of
635
+ phi (blank)-symbols and transits to emission state when the correct label is
636
+ observed. An emission state accepts repetition of the label and transits to
637
+ the next phi states at any time (so called epsilon-transition).
638
+ Below, `B` denotes the batch size, `T` denotes the time steps in `logits`,
639
+ and `N` denotes the time steps in `labels`.
640
+ Args:
641
+ logits: (B, T, K)-array containing log-probabilities of each class.
642
+ logits_attention_mask: (B, T)-array. Attention mask for `logits`; zero entries mark padded frames.
643
+ labels: (B, N)-array containing reference integer labels.
644
+ labelpaddings: (B, N)-array derived inside this function from `labels` (-100 entries mark padding). Currently,
645
+ `labels` must be right-padded, i.e. each row of `labelpaddings` must be
646
+ repetition of zeroes, followed by repetition of ones.
647
+ blank_id: Id for blank token.
648
+ loss_reduction: one of "mean", "sum", "none"
649
+ - "none": no reduction is applied.
650
+ - "mean": output loss will be divided by target lengths and then the
651
+ mean over the batch is taken.
652
+ - "sum": output loss are summed over batch
653
+ output_emission_dict: whether to output additional information about the emission probs
654
+ Returns:
655
+ A pair of `(per_seq_loss, aux)`.
656
+ per_seq_loss:
657
+ (B,)-array containing loss values for each sequence in the batch.
658
+ aux: Dictionary containing interim variables used for computing losses.
659
+ aux['logalpha_phi']: (T, B, N+1)-array. Log-forward-probabilities of each
660
+ phi-state corresponding to the n-th label.
661
+ aux['logalpha_emit']: (T, B, N)-array. Log-forward-probabilities of each
662
+ emission-state corresponding to the n-th label.
663
+ aux['logprobs_phi']: (T, B, 1)-array. Probability of the phi-symbol
664
+ corresponding to each time frame.
665
+ aux['logprobs_emit']: (T, B, N)-array. Probability of the n-th label
666
+ corresponding to each time frame.
667
+ """
668
+ # label paddings are indicated by -100
669
+ labelpaddings = labels < 0
670
+ # logit paddings are the inverse of attention_mask
671
+ logitpaddings = ~logits_attention_mask
672
+
673
+ # Copied from https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
674
+ batchsize, unused_maxinputlen, num_classes = logits.shape
675
+ batchsize_, maxlabellen = labels.shape
676
+
677
+ logprobs = jax.nn.log_softmax(logits)
678
+ labellens = maxlabellen - jnp.sum(labelpaddings, axis=1).astype(jnp.int32)
679
+
680
+ # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
681
+ repeat = (labels[:, :-1] == labels[:, 1:]).astype(jnp.float32)
682
+ repeat = jnp.pad(repeat, ((0, 0), (0, 1)))
683
+
684
+ logprobs_phi = logprobs[:, :, blank_id : blank_id + 1] # [B, T, 1]
685
+ logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2)) # [T, B, 1]
686
+
687
+ one_hot = jax.nn.one_hot(labels, num_classes=num_classes) # [B, N, K]
688
+ logprobs_emit = jnp.einsum("btk,bnk->btn", logprobs, one_hot)
689
+ logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2)) # [T, B, N]
690
+
691
+ logalpha_phi_init = jnp.ones((batchsize, maxlabellen + 1)) * log_epsilon # [B, N]
692
+ logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)
693
+ logalpha_emit_init = jnp.ones((batchsize, maxlabellen)) * log_epsilon # [B, N]
694
+
695
+ def loop_body(prev, x):
696
+ prev_phi, prev_emit = prev
697
+ # emit-to-phi epsilon transition, except if the next label is repetition
698
+ prev_phi_orig = prev_phi
699
+ prev_phi = prev_phi.at[:, 1:].set(jnp.logaddexp(prev_phi[:, 1:], prev_emit + log_epsilon * repeat))
700
+
701
+ logprob_emit, logprob_phi, pad = x
702
+
703
+ # phi-to-emit transition
704
+ next_emit = jnp.logaddexp(prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit)
705
+ # self-loop transition
706
+ next_phi = prev_phi + logprob_phi
707
+ # emit-to-phi blank transition only when the next label is repetition
708
+ next_phi = next_phi.at[:, 1:].set(
709
+ jnp.logaddexp(next_phi[:, 1:], prev_emit + logprob_phi + log_epsilon * (1.0 - repeat))
710
+ )
711
+
712
+ pad = pad.reshape((batchsize, 1))
713
+ next_emit = pad * prev_emit + (1.0 - pad) * next_emit
714
+ next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi
715
+
716
+ return (next_phi, next_emit), (next_phi, next_emit)
717
+
718
+ xs = (logprobs_emit, logprobs_phi, logitpaddings.transpose((1, 0)))
719
+ _, (logalpha_phi, logalpha_emit) = jax.lax.scan(loop_body, (logalpha_phi_init, logalpha_emit_init), xs)
720
+
721
+ # last row needs to be updated with the last epsilon transition
722
+ logalpha_phi_last = logalpha_phi[-1].at[:, 1:].set(jnp.logaddexp(logalpha_phi[-1, :, 1:], logalpha_emit[-1]))
723
+ logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)
724
+
725
+ # extract per_seq_loss
726
+ one_hot = jax.nn.one_hot(labellens, num_classes=maxlabellen + 1) # [B, N+1]
727
+ per_seq_loss = -jnp.einsum("bn,bn->b", logalpha_phi_last, one_hot)
728
+
729
+ if loss_reduction == "mean":
730
+ target_lengths = labelpaddings.shape[-1] - labelpaddings.sum(axis=-1)
731
+ loss = (per_seq_loss / target_lengths).mean()
732
+ elif loss_reduction == "sum":
733
+ loss = per_seq_loss.sum()
734
+ else:
735
+ loss = per_seq_loss
736
+
737
+ if not output_emission_dict:
738
+ return loss
739
+
740
+ return loss, {
741
+ "logalpha_phi": logalpha_phi,
742
+ "logalpha_emit": logalpha_emit,
743
+ "logprobs_phi": logprobs_phi,
744
+ "logprobs_emit": logprobs_emit,
745
+ }
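A small self-contained sketch of calling this loss on dummy data (shapes and values are arbitrary; the boolean attention mask and the -100 label padding follow the conventions described in the docstring above):

import numpy as np
import jax.numpy as jnp

logits = jnp.asarray(np.random.randn(2, 50, 32))              # (B=2, T=50, K=32) unnormalized scores
attention_mask = jnp.ones((2, 50), dtype=bool)                # no frame padding in this toy batch
labels = jnp.asarray(np.random.randint(1, 32, size=(2, 10)))  # avoid the blank id (0)
labels = labels.at[1, 8:].set(-100)                           # right-pad the second sequence with -100
loss = ctc_loss(logits, attention_mask, labels, blank_id=0, loss_reduction="mean")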
746
+
747
+
748
+ def make_dataset(seed=42):
749
+ # Pre-processing dataset
750
+ import re
751
+
752
+ def map_nst(entry):
753
+ text = entry["text"].lower()
754
+ text = text.replace("(...vær stille under dette opptaket...)", "")
755
+ text = re.sub('[áàâ]', 'a', text)
756
+ text = re.sub('[ä]', 'æ', text)
757
+ text = re.sub('[éèëê]', 'e', text)
758
+ text = re.sub('[íìïî]', 'i', text)
759
+ text = re.sub('[óòöô]', 'o', text)
760
+ text = re.sub('[ö]', 'ø', text)
761
+ text = re.sub('[ç]', 'c', text)
762
+ text = re.sub('[úùüû]', 'u', text)
763
+ # text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
764
+ text = re.sub('\s+', ' ', text)
765
+ return {"text": text}
766
+
767
+ def filter_nst(entry):
768
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
769
+ return False # Too short
770
+ if re.match("pIW|CA", entry["type"]):
771
+ return False # Spelling out words
772
+ return True
773
+
774
+ def filter_npsc(entry):
775
+ # False if there are digits in the text
776
+ if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
777
+ return False # Too short
778
+ if re.search("\d", entry["text"]):
779
+ return False
780
+ return True
781
+
782
+ def map_npsc(entry):
783
+ batch = {"text": entry["text"].lower()}
784
+ batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
785
+ batch["text"] = re.sub('[ä]', 'æ', batch["text"])
786
+ batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
787
+ batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
788
+ batch["text"] = re.sub('[óòöô]', 'o', batch["text"])
789
+ batch["text"] = re.sub('[ö]', 'ø', batch["text"])
790
+ batch["text"] = re.sub('[ç]', 'c', batch["text"])
791
+ batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
792
+ batch["text"] = re.sub('\s', ' ', batch["text"])
793
+ batch["text"] = re.sub('<ee>', 'eee', batch["text"])
794
+ batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
795
+ batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
796
+ batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
797
+ # batch["text"] = re.sub('<inaudible>', '?', batch["text"])
798
+ if "<" in batch["text"]:
799
+ raise ValueError(batch["text"])
800
+ return batch
801
+
802
+ nst = datasets.load_dataset("NbAiLab/NST", "no-close")
803
+ npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
804
+ # TODO NST_hesitate
805
+
806
+ split = len(npsc["train"]) / (len(npsc["train"]) + len(npsc["validation"])) # Use same train/val ratio as NPSC
807
+ nst_train = nst["train"].train_test_split(train_size=split, seed=seed)
808
+ nst["train"] = nst_train["train"]
809
+ nst["validation"] = nst_train["test"]
810
+
811
+ nst = nst.filter(filter_nst).map(map_nst).shuffle(seed=seed)
812
+ npsc = npsc.filter(filter_npsc).map(map_npsc).shuffle(seed=seed)
813
+
814
+ npsc_base = npsc.remove_columns([col for col in npsc["train"].column_names if col not in ["text", "audio"]])
815
+ nst_base = nst.remove_columns([col for col in nst["train"].column_names if col not in ["text", "audio"]])
816
+
817
+ combined = {}
818
+ for split in "train", "validation", "test":
819
+ probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples
820
+ probs = (probs / probs.sum()).tolist()
821
+ comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
822
+ combined[split] = comb
823
+
824
+ return datasets.DatasetDict(**combined)
825
+
826
+ def main():
827
+ # 1. Parse input arguments
828
+ # See all possible arguments in src/transformers/training_args.py
829
+ # or by passing the --help flag to this script.
830
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
831
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
832
+
833
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
834
+ # If we pass only one argument to the script and it's the path to a json file,
835
+ # let's parse it to get our arguments.
836
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
837
+ else:
838
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
839
+
840
+ # 2. Setup logging
841
+ # Make one log on every process with the configuration for debugging.
842
+ logging.basicConfig(
843
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
844
+ datefmt="%m/%d/%Y %H:%M:%S",
845
+ handlers=[logging.StreamHandler(sys.stdout)],
846
+ )
847
+ # Set the verbosity to info of the Transformers logger.
848
+ # We only want one process per machine to log things on the screen.
849
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
850
+ if jax.process_index() == 0:
851
+ datasets.utils.logging.set_verbosity_warning()
852
+ transformers.utils.logging.set_verbosity_info()
853
+ else:
854
+ datasets.utils.logging.set_verbosity_error()
855
+ transformers.utils.logging.set_verbosity_error()
856
+
857
+ # Set up wandb run
858
+ if jax.process_index() == 0:
859
+ wandb.init(project=data_args.wandb_project, name=data_args.wandb_name, job_type=data_args.wandb_job_type)
860
+
861
+ logger.info("Training/evaluation parameters %s", training_args)
862
+
863
+ # Set the default TPU matmul precision and display the number of devices
864
+ jax.config.update("jax_default_matmul_precision", training_args.matmul_precision)
865
+ logger.info(f"JAX devices: {jax.device_count()}, matmul precision: {training_args.matmul_precision}")
866
+
867
+ # 4. Load dataset
868
+
869
+ set_seed(training_args.seed)
870
+ raw_datasets = make_dataset(seed=training_args.seed)
871
+
872
+ # raw_datasets = DatasetDict()
873
+
874
+ # if training_args.do_train:
875
+ # raw_datasets["train"] = load_dataset(
876
+ # data_args.dataset_name,
877
+ # data_args.dataset_config_name,
878
+ # split=data_args.train_split_name,
879
+ # cache_dir=data_args.dataset_cache_dir,
880
+ # use_auth_token=True if model_args.use_auth_token else None,
881
+ # )
882
+
883
+ # if training_args.do_eval:
884
+ # raw_datasets["eval"] = load_dataset(
885
+ # data_args.dataset_name,
886
+ # data_args.dataset_config_name,
887
+ # split=data_args.eval_split_name,
888
+ # cache_dir=data_args.dataset_cache_dir,
889
+ # use_auth_token=True if model_args.use_auth_token else None,
890
+ # )
891
+
892
+ # if training_args.do_predict:
893
+ # test_split = data_args.test_split_name.split("+")
894
+ # for split in test_split:
895
+ # raw_datasets[split] = load_dataset(
896
+ # data_args.dataset_name,
897
+ # data_args.dataset_config_name,
898
+ # split=split,
899
+ # cache_dir=data_args.dataset_cache_dir,
900
+ # use_auth_token=True if model_args.use_auth_token else None,
901
+ # )
902
+
903
+ if not training_args.do_train and not training_args.do_eval and not training_args.do_predict:
904
+ raise ValueError(
905
+ "Cannot not train, not do evaluation and not do prediction. At least one of "
906
+ "training, evaluation or prediction has to be done."
907
+ )
908
+
909
+ # if not training, there is no need to run multiple epochs
910
+ if not training_args.do_train:
911
+ training_args.num_train_epochs = 1
912
+
913
+ if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
914
+ raise ValueError(
915
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
916
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
917
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
918
+ )
919
+
920
+ if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
921
+ raise ValueError(
922
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
923
+ "Make sure to set `--text_column_name` to the correct text column - one of "
924
+ f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
925
+ )
926
+
927
+ # 5. Load pretrained model, tokenizer, and feature extractor
928
+ #
929
+ # Distributed training:
930
+ # The .from_pretrained methods guarantee that only one local process can concurrently
931
+ config = Wav2Vec2Config.from_pretrained(
932
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
933
+ cache_dir=model_args.cache_dir,
934
+ revision=model_args.model_revision,
935
+ use_auth_token=True if model_args.use_auth_token else None,
936
+ )
937
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
938
+ model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
939
+ cache_dir=model_args.cache_dir,
940
+ revision=model_args.model_revision,
941
+ use_auth_token=True if model_args.use_auth_token else None,
942
+ )
943
+ tokenizer = AutoTokenizer.from_pretrained(
944
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
945
+ cache_dir=model_args.cache_dir,
946
+ revision=model_args.model_revision,
947
+ use_auth_token=True if model_args.use_auth_token else None,
948
+ )
949
+ # update config according to training args, model args, and tokenizer attributes
950
+ config.update(
951
+ {
952
+ "feat_proj_dropout": model_args.feat_proj_dropout,
953
+ "attention_dropout": model_args.attention_dropout,
954
+ "hidden_dropout": model_args.hidden_dropout,
955
+ "final_dropout": model_args.final_dropout,
956
+ "mask_time_prob": model_args.mask_time_prob,
957
+ "mask_time_length": model_args.mask_time_length,
958
+ "mask_feature_prob": model_args.mask_feature_prob,
959
+ "mask_feature_length": model_args.mask_feature_length,
960
+ "gradient_checkpointing": training_args.gradient_checkpointing,
961
+ "layerdrop": model_args.layerdrop,
962
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
963
+ "ctc_zero_infinity": model_args.ctc_zero_infinity,
964
+ "pad_token_id": tokenizer.pad_token_id,
965
+ "vocab_size": tokenizer.vocab_size, # len(tokenizer),
966
+ "activation_dropout": model_args.activation_dropout,
967
+ }
968
+ )
969
+
970
+ if tokenizer.do_lower_case and data_args.dataset_name != "librispeech_asr":
971
+ raise ValueError(
972
+ "Setting the tokenizer attribute `do_lower_case` to `True` converts all input strings to "
973
+ "uppercase prior to tokenization. This should only be done when the tokenizer is built on an uppercased corpus,"
974
+ "i.e. for the dataset `librispeech_asr` only. If your dataset is not `librispeech_asr`, the tokenizer is most likely "
975
+ "built on a lowercased corpus. In this case, set `tokenizer.do_lower_case` to `False`."
976
+ )
977
+
978
+ if training_args.precision == "full_mixed":
979
+ dtype = jnp.bfloat16
980
+ training_args.mixed_precision = True
981
+ elif training_args.precision == "half_mixed":
982
+ dtype = jnp.bfloat16
983
+ training_args.mixed_precision = False
984
+ else:
985
+ dtype = jnp.float32
986
+ training_args.mixed_precision = False
987
+
988
+ model = FlaxWav2Vec2ForCTC.from_pretrained(
989
+ model_args.model_name_or_path,
990
+ config=config,
991
+ dtype=dtype,
992
+ cache_dir=model_args.cache_dir,
993
+ revision=model_args.model_revision,
994
+ use_auth_token=True if model_args.use_auth_token else None,
995
+ )
996
+
997
+ # 6. Resample speech dataset ALWAYS
998
+ raw_datasets = raw_datasets.cast_column(
999
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
1000
+ )
1001
+
1002
+ # 7. Preprocessing the datasets.
1003
+ # We need to read the audio files as arrays and tokenize the targets.
1004
+ max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
1005
+ min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
1006
+ max_target_length = data_args.max_label_length
1007
+ min_target_length = data_args.min_label_length
1008
+ pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
1009
+ audio_column_name = data_args.audio_column_name
1010
+ num_workers = data_args.preprocessing_num_workers
1011
+ text_column_name = data_args.text_column_name
1012
+ model_input_name = feature_extractor.model_input_names[0]
1013
+ do_lower_case = data_args.do_lower_case
1014
+ dataset_name = data_args.dataset_name
1015
+ chars_to_ignore = ', ? . ! - ; : " “ % ‘ ” ?'.split(" ")
1016
+ chars_to_ignore_regex = f'[{"".join(chars_to_ignore)}]'
1017
+ # gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
1018
+ # gigaspeech_disfluencies = ["<other>", "<sil>"]
1019
+ # swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
1020
+ # "[vocalized-noise]", "_1"]
1021
+ # swb_punctuations = ["{", "}", "[", "]-", "]"]
1022
+ # earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>", "<unk>"]
1023
+ ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
1024
+ "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>", "<sil>", ""]
1025
+
1026
+ if training_args.do_train and data_args.max_train_samples is not None:
1027
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
1028
+
1029
+ if training_args.do_eval and data_args.max_eval_samples is not None:
1030
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
1031
+
1032
+ if training_args.do_predict and data_args.max_test_samples is not None:
1033
+ for split in test_split:
1034
+ raw_datasets[split] = raw_datasets[split].select(range(data_args.max_test_samples))
1035
+
1036
+ if training_args.do_train and data_args.remove_punctuation:
1037
+
1038
+ def remove_punctuation(batch):
1039
+ batch[text_column_name] = (
1040
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).replace("'", "").replace('"', "")
1041
+ )
+ return batch # datasets.map only applies the update if the mapped function returns the example dict
1042
+
1043
+ raw_datasets["train"] = raw_datasets["train"].map(
1044
+ remove_punctuation,
1045
+ num_proc=data_args.preprocessing_num_workers,
1046
+ desc="removing punctuation from train split",
1047
+ )
1048
+
1049
+ # filter data where the targets are ignored in scoring
1050
+ def is_target_labels(input_str):
1051
+ return input_str.lower() not in ignore_segments
1052
+
1053
+ raw_datasets = raw_datasets.filter(
1054
+ is_target_labels,
1055
+ num_proc=num_workers,
1056
+ input_columns=[text_column_name],
1057
+ desc="filtering data where the targets are ignored in scoring",
1058
+ )
1059
+
1060
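+ # convert raw audio into model input features, record input/label lengths for later filtering, and tokenize the (optionally lower-cased) transcript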
+ def prepare_dataset(batch):
1061
+ # process audio
1062
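+ # if the audio cannot be decoded, fall back to a one-sample silent array; such examples should be dropped by the input-length filter below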
+ try:
1063
+ sample = batch[audio_column_name]
1064
+ except ValueError:
1065
+ sample = {"array": np.array([0.]), "sampling_rate": feature_extractor.sampling_rate}
1066
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1067
+ # process audio length
1068
+ batch[model_input_name] = inputs.input_values[0]
1069
+ batch["input_length"] = len(batch["input_values"])
1070
+
1071
+ # process targets
1072
+ input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
1073
+
1074
+ # if dataset_name == "google/xtreme_s":
1075
+ # # Finally, we tokenize the processed text
1076
+ # batch["labels"] = tokenizer(input_str).input_ids
1077
+ # batch["labels_length"] = len(batch["labels"])
1078
+ # return batch
1079
+
1080
+ # # Common Voice 9
1081
+ # if input_str.startswith('"') and input_str.endswith('"'):
1082
+ # # we can remove trailing quotation marks as they do not affect the transcription
1083
+ # input_str = input_str[1:-1]
1084
+ # # normalize quotation marks
1085
+ # input_str = re.sub(r'["“”]', '"', input_str)
1086
+ # # normalize apostrophes
1087
+ # input_str = re.sub(r"[’']", "'", input_str)
1088
+ # # normalize hyphens
1089
+ # input_str = re.sub(r"[—–]", "-", input_str)
1090
+ # # replace double quotation marks with single
1091
+ # input_str = input_str.replace('""', '"')
1092
+ # if dataset_name == "mozilla-foundation/common_voice_9_0" and len(input_str):
1093
+ # # for CV9, we'll normalize the text to always finish with punctuation
1094
+ # if input_str[-1] not in [".", "?", "!"]:
1095
+ # input_str = input_str + "."
1096
+
1097
+ # # TEDLIUM-3
1098
+ # # delete the <unk> token from the text and replace spaced apostrophes with un-spaced
1099
+ # input_str = input_str.replace("<unk>", "").replace(" '", "'")
1100
+
1101
+ # # GigaSpeech
1102
+ # for disfluency in gigaspeech_disfluencies:
1103
+ # input_str = input_str.replace(disfluency, "")
1104
+ # # convert spelled out punctuation to symbolic form
1105
+ # for punctuation, replacement in gigaspeech_punctuation.items():
1106
+ # input_str = input_str.replace(punctuation, replacement)
1107
+ # if dataset_name == "speechcolab/gigaspeech" and len(input_str):
1108
+ # # for GS, we'll normalize the text to always finish with punctuation
1109
+ # if input_str[-1] not in [".", "?", "!"]:
1110
+ # input_str = input_str + "."
1111
+
1112
+ # # SWB
1113
+ # for disfluency in swb_disfluencies:
1114
+ # input_str = input_str.replace(disfluency, "")
1115
+ # # remove parenthesised text (test data only)
1116
+ # input_str = re.sub("[\(].*?[\)]", "", input_str)
1117
+ # for punctuation in swb_punctuations:
1118
+ # input_str = input_str.replace(punctuation, "")
1119
+ # # replace anomalous words with their correct transcriptions
1120
+ # split_str = input_str.split("/")
1121
+ # if len(split_str) > 1:
1122
+ # input_str = " ".join(
1123
+ # [" ".join([" ".join(i.split(" ")[:-1]) for i in split_str])] + [split_str[-1].split(" ")[-1]])
1124
+
1125
+ # # Earnings 22
1126
+ # for disfluency in earnings_disfluencies:
1127
+ # input_str = input_str.replace(disfluency, "")
1128
+ # # replace mal-formatted ellipsis
1129
+ # input_str = input_str.replace("…", ".")
1130
+
1131
+ # JIWER compliance
1132
+ # remove multiple spaces
1133
+ input_str = re.sub(r"\s\s+", " ", input_str)
1134
+ # strip trailing spaces
1135
+ input_str = input_str.strip()
1136
+
1137
+ # Finally, we tokenize the processed text
1138
+ batch["labels"] = tokenizer(input_str).input_ids
1139
+ batch["labels_length"] = len(batch["labels"])
1140
+ return batch
1141
+
1142
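+ # drop the original columns so that only the model inputs, labels and lengths remain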
+ vectorized_datasets = raw_datasets.map(
1143
+ prepare_dataset,
1144
+ remove_columns=next(iter(raw_datasets.values())).column_names,
1145
+ num_proc=num_workers,
1146
+ desc="preprocess dataset",
1147
+ )
1148
+
1149
+ # filter data with inputs shorter than min_input_length or longer than max_input_length
1150
+ def is_audio_in_length_range(length):
1151
+ return length > min_input_length and length < max_input_length
1152
+
1153
+ vectorized_datasets = vectorized_datasets.filter(
1154
+ is_audio_in_length_range,
1155
+ num_proc=num_workers,
1156
+ input_columns=["input_length"],
1157
+ )
1158
+
1159
+ # filter data with targets shorter than min_target_length or longer than max_target_length
1160
+ def is_labels_in_length_range(length):
1161
+ return length > min_target_length # and length < max_target_length
1162
+
1163
+ vectorized_datasets = vectorized_datasets.filter(
1164
+ is_labels_in_length_range,
1165
+ num_proc=num_workers,
1166
+ input_columns=["labels_length"],
1167
+ )
1168
+
1169
+ # for large datasets it is advised to run the preprocessing on a
1170
+ # single machine first with `args.preprocessing_only` since there will most likely
1171
+ # be a timeout when running the script in distributed mode.
1172
+ # In a second step `args.preprocessing_only` can then be set to `False` to load the
1173
+ # cached dataset
1174
+ if data_args.preprocessing_only:
1175
+ cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1176
+ logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1177
+ return
1178
+
1179
+ # 8. Load Metrics
1180
+ wer_metric = load_metric("wer")
1181
+ cer_metric = load_metric("cer")
1182
+
1183
+ def compute_metrics(pred_ids: List[List[int]], label_ids: List[List[int]]):
1184
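+ # -100 marks padded label positions; map them back to the pad token id so that decoding recovers the reference text only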
+ padded_ids = np.where(np.asarray(label_ids) == -100, tokenizer.pad_token_id, np.asarray(label_ids))
1185
+
1186
+ pred_str = tokenizer.batch_decode(pred_ids)
1187
+ # we do not want to group tokens when computing the metrics
1188
+ label_str = tokenizer.batch_decode(padded_ids, group_tokens=False)
1189
+
1190
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
1191
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
1192
+
1193
+ return {"wer": wer, "cer": cer}, pred_str, label_str
1194
+
1195
+ # 9. save feature extractor, tokenizer and config
1196
+ feature_extractor.save_pretrained(training_args.output_dir)
1197
+ tokenizer.save_pretrained(training_args.output_dir)
1198
+ config.save_pretrained(training_args.output_dir)
1199
+
1200
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
1201
+
1202
+ data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
1203
+ processor=processor,
1204
+ input_padding="longest",
1205
+ pad_input_to_multiple_of=pad_input_to_multiple_of,
1206
+ max_label_length=data_args.max_label_length,
1207
+ )
1208
+
1209
+ # Enable tensorboard only on the master node
1210
+ has_tensorboard = is_tensorboard_available()
1211
+ if has_tensorboard and jax.process_index() == 0:
1212
+ try:
1213
+ from flax.metrics.tensorboard import SummaryWriter
1214
+
1215
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
1216
+ except ImportError as ie:
1217
+ has_tensorboard = False
1218
+ logger.warning(
1219
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
1220
+ )
1221
+ else:
1222
+ logger.warning(
1223
+ "Unable to display metrics through TensorBoard because the package is not installed: "
1224
+ "Please run `pip install tensorboard` to enable."
1225
+ )
1226
+
1227
+ # 10. Handle the repository creation
1228
+ if training_args.push_to_hub:
1229
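+ # make sure *.wandb run files are tracked with Git LFS before pushing them to the Hub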
+ with open(os.path.join(training_args.output_dir, ".gitattributes"), "r+") as f:
1230
+ git_lfs_extensions = f.read()
1231
+ if "*.wandb" not in git_lfs_extensions:
1232
+ f.write("*.wandb filter=lfs diff=lfs merge=lfs -text")
1233
+ if training_args.hub_model_id is None:
1234
+ repo_name = get_full_repo_name(
1235
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
1236
+ )
1237
+ else:
1238
+ repo_name = training_args.hub_model_id
1239
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
1240
+
1241
+ # 11. Initialize our training
1242
+ rng = jax.random.PRNGKey(training_args.seed)
1243
+ rng, dropout_rng = jax.random.split(rng)
1244
+
1245
+ # Store some constants
1246
+ max_steps = int(training_args.max_steps)
1247
+ gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1248
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
1249
+ batch_size_per_update = train_batch_size * gradient_accumulation_steps
1250
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1251
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
1252
+ to_dtype = to_bf16 if training_args.mixed_precision else to_fp32
1253
+
1254
+ if training_args.do_train:
1255
+ num_train_samples = len(vectorized_datasets["train"])
1256
+ steps_per_epoch = num_train_samples // batch_size_per_update
1257
+ if max_steps > 0:
1258
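+ # ceiling division: the number of epochs needed to cover max_steps optimizer updates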
+ num_epochs = -(training_args.max_steps // -steps_per_epoch)
1259
+ total_train_steps = max_steps
1260
+ else:
1261
+ num_epochs = int(training_args.num_train_epochs)
1262
+ total_train_steps = steps_per_epoch * num_epochs
1263
+
1264
+ # Create learning rate schedule
1266
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
1267
+ total_train_steps,
1268
+ training_args.warmup_steps,
1269
+ training_args.learning_rate,
1270
+ )
1271
+
1272
+ # We use Optax's "masking" functionality to not apply weight decay
1273
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
1274
+ # mask boolean with the same structure as the parameters.
1275
+ # The mask is True for parameters that should be decayed.
1276
+ # Note that this mask is specifically adapted for FlaxWav2Vec2 and FlaxBart.
1277
+ # For FlaxT5, one should correct the layer norm parameter naming
1278
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
1279
+ def decay_mask_fn(params):
1280
+ flat_params = traverse_util.flatten_dict(params)
1281
+ layer_norm_params = [
1282
+ (name, "scale")
1283
+ for name in ["layer_norm", "self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
1284
+ ]
1285
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
1286
+ return traverse_util.unflatten_dict(flat_mask)
1287
+
1288
+ if training_args.adafactor:
1289
+ # Create Adafactor optimizer
1290
+ optim = optax.adafactor(
1291
+ learning_rate=linear_decay_lr_schedule_fn,
1292
+ dtype_momentum=jnp.bfloat16 if training_args.mixed_precision else jnp.float32,
1293
+ weight_decay_rate=training_args.weight_decay,
1294
+ weight_decay_mask=decay_mask_fn,
1295
+ )
1296
+ else:
1297
+ # Create AdamW optimizer
1298
+ optim = optax.adamw(
1299
+ learning_rate=linear_decay_lr_schedule_fn,
1300
+ b1=training_args.adam_beta1,
1301
+ b2=training_args.adam_beta2,
1302
+ eps=training_args.adam_epsilon,
1303
+ weight_decay=training_args.weight_decay,
1304
+ mask=decay_mask_fn,
1305
+ )
1306
+
1307
+ # Optax MultiSteps for gradient accumulation. We'll only call this optimizer transformation if gradient accumulation is required (i.e. gradient accumulation steps > 1)
1308
+ if training_args.multisteps and gradient_accumulation_steps > 1:
1309
+ optim = optax.MultiSteps(optim, gradient_accumulation_steps, use_grad_mean=False)
1310
+ else:
1311
+ num_epochs = 0
1312
+ total_train_steps = 0
1313
+ num_train_samples = 0
1314
+ optim = None
1315
+
1316
+ # Setup train state
1317
+ state = MixedPrecisionTrainState.create(
1318
+ apply_fn=model.__call__,
1319
+ get_attention_mask_fn=model._get_feature_vector_attention_mask,
1320
+ params=model.params,
1321
+ tx=optim,
1322
+ to_dtype=to_dtype,
1323
+ dropout_rng=dropout_rng,
1324
+ max_grad_norm=training_args.max_grad_norm,
1325
+ )
1326
+
1327
+ # Replicate the train state on each device
1328
+ state = state.replicate()
1329
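+ # for Wav2Vec2 CTC models the blank token is the tokenizer's pad token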
+ blank_id = model.config.pad_token_id
1330
+
1331
+ # Define gradient update step fn
1332
+ def train_step(state, batch):
1333
+ # only one single rng per grad step, with or without accumulation, as the graph should be identical over one effective training batch
1334
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
1335
+
1336
+ def compute_loss(params, minibatch):
1337
+ labels = minibatch.pop("labels")
1338
+ logits = state.apply_fn(
1339
+ **minibatch,
1340
+ params=params,
1341
+ dropout_rng=dropout_rng,
1342
+ freeze_feature_encoder=model_args.freeze_feature_encoder,
1343
+ train=True,
1344
+ )[0]
1345
+ logits_mask = state.get_attention_mask_fn(logits.shape[1], minibatch["attention_mask"])
1346
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1347
+
1348
+ return loss
1349
+
1350
+ grad_fn = jax.value_and_grad(compute_loss)
1351
+
1352
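+ # with optax.MultiSteps the optimizer transformation accumulates gradients itself, so a single grad call per step suffices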
+ if gradient_accumulation_steps == 1 or training_args.multisteps:
1353
+ loss, grad = grad_fn(to_dtype(state.params), batch)
1354
+
1355
+ # Custom gradient accumulation
1356
+ else:
1357
+ # add a first dimension over gradient_accumulation_steps for minibatch slices
1358
+ batch = jax.tree_map(
1359
+ lambda x: x.reshape(
1360
+ gradient_accumulation_steps, training_args.per_device_train_batch_size, *x.shape[1::]
1361
+ ),
1362
+ batch,
1363
+ )
1364
+
1365
+ def accum_minibatch_step(accum_grad, minibatch):
1366
+ # compute the loss and gradient over the minibatch and accumulate the gradient
1367
+ loss, grad = grad_fn(to_dtype(state.params), minibatch)
1368
+ return jax.tree_map(jnp.add, accum_grad, grad), loss
1369
+
1370
+ # create an initial (zero) state for accumulating gradients
1371
+ init_grad = jax.tree_map(jnp.zeros_like, to_dtype(state.params))
1372
+ # loop accum minibatch step over the number of gradient accumulation steps
1373
+ grad, loss = jax.lax.scan(accum_minibatch_step, init_grad, batch)
1374
+
1375
+ # update state
1376
+ new_state = state.apply_gradients(
1377
+ grads=grad,
1378
+ dropout_rng=new_dropout_rng,
1379
+ to_dtype=to_dtype,
1380
+ )
1381
+
1382
+ # compute gradient norms over all layers and globally for detailed monitoring
1383
+ layer_grad_norm = jax.tree_map(jnp.linalg.norm, grad)
1384
+ logs = {
1385
+ "layer_grad_norm": layer_grad_norm,
1386
+ "grad_norm": jnp.linalg.norm(jax.tree_util.tree_leaves(layer_grad_norm)),
1387
+ }
1388
+
1389
+ # compute parameter norms over all layers and globally for detailed monitoring
1390
+ layer_param_norm = jax.tree_map(jnp.linalg.norm, new_state.params)
1391
+ logs["layer_param_norm"] = layer_param_norm
1392
+ logs["param_norm"] = jnp.linalg.norm(jax.tree_util.tree_leaves(layer_param_norm))
1393
+
1394
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
1395
+ metrics.update(logs)
1396
+
1397
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1398
+ # metrics = to_fp32(metrics)
1399
+
1400
+ return new_state, metrics
1401
+
1402
+ # Define eval fn
1403
+ def eval_step(params, batch):
1404
+ labels = batch.pop("labels")
1405
+ logits = model(**batch, params=params, train=False)[0]
1406
+
1407
+ logits_mask = model._get_feature_vector_attention_mask(logits.shape[1], batch["attention_mask"])
1408
+ loss = ctc_loss(logits, logits_mask, labels, blank_id, loss_reduction="mean")
1409
+
1410
+ pred_ids = jnp.argmax(logits, axis=-1)
1411
+
1412
+ # summarize metrics
1413
+ metrics = {"loss": loss}
1414
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
1415
+ # metrics = to_fp32(metrics)
1416
+ return metrics, pred_ids
1417
+
1418
+ # Create parallel version of the train and eval step
1419
+ if training_args.do_train:
1420
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
1421
+
1422
+ if training_args.do_eval:
1423
+ p_eval_step = jax.pmap(eval_step, "batch")
1424
+
1425
+ def run_evaluation(step):
1426
+ if training_args.do_eval:
1427
+ # ======================== Evaluating ==============================
1428
+ eval_metrics = []
1429
+ eval_preds = []
1430
+ eval_labels = []
1431
+
1432
+ # Generate eval set by sequentially sampling indices from the eval dataset and grouping by length
1433
+ eval_samples_idx = get_grouped_indices(vectorized_datasets["eval"], eval_batch_size)
1434
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1435
+
1436
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
1437
+ samples = [vectorized_datasets["eval"][int(idx)] for idx in batch_idx]
1438
+ batch = data_collator(samples)
1439
+ labels = batch["labels"]
1440
+
1441
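+ # pad_shard_unpad pads the batch up to a full per-device batch, shards it across devices, runs the pmapped eval step and un-pads the returned predictions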
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1442
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1443
+ eval_metrics.append(metrics)
1444
+
1445
+ eval_labels.extend(labels)
1446
+
1447
+ # normalize eval metrics
1448
+ eval_metrics = get_metrics(eval_metrics)
1449
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1450
+ eval_metrics = to_fp32(eval_metrics)
1451
+
1452
+ # always run compute metrics
1453
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1454
+ eval_metrics.update(error_rate_metric)
1455
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1456
+
1457
+ # Print metrics and update progress bar
1458
+ desc = f"Step... ({step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1459
+ epochs.write(desc)
1460
+ epochs.desc = desc
1461
+
1462
+ # Save metrics
1463
+ write_wandb_log(eval_metrics, step, prefix="eval")
1464
+ write_wandb_pred(pred_str, label_str, step)
1465
+ # if has_tensorboard and jax.process_index() == 0:
1466
+ # write_eval_metric(summary_writer, eval_metrics, step, pred_str=pred_str)
1467
+
1468
+ def save_checkpoint(step):
1469
+ # save and push checkpoint to the hub
1470
+ if jax.process_index() == 0:
1471
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
1472
+ model.save_pretrained(training_args.output_dir, params=params)
1473
+ tokenizer.save_pretrained(training_args.output_dir)
1474
+ if training_args.push_to_hub:
1475
+ repo.push_to_hub(commit_message=f"{wandb.run.id}: saving weights and logs of step {int(step / 1000)}k", blocking=False)
1476
+
1477
+ logger.info("***** Running training *****")
1478
+ logger.info(f" Num examples = {num_train_samples}")
1479
+ logger.info(f" Num Epochs = {num_epochs}")
1480
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
1481
+ logger.info(f" Num gradient accumulation steps = {gradient_accumulation_steps}")
1482
+ logger.info(f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}")
1483
+ logger.info(f" Total optimization steps = {total_train_steps}")
1484
+ logger.info(f" Gradient checkpointing: {config.gradient_checkpointing}")
1485
+ logger.info(f" Use scan: {config.use_scan}")
1486
+ logger.info(f" Fuse matmuls: {config.fuse_matmuls}")
1487
+
1488
+ train_time = cur_step = 0
1489
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
1490
+ for epoch in epochs:
1491
+ if training_args.do_train:
1492
+ # ======================== Training ================================
1493
+ train_start = time.time()
1494
+
1495
+ # Create sampling rng
1496
+ rng, input_rng = jax.random.split(rng)
1497
+
1498
+ # Generate an epoch by randomly shuffling sampling indices from the train dataset and grouping by length
1499
+ train_samples_idx = get_grouped_indices(vectorized_datasets["train"], batch_size_per_update, input_rng)
1500
+ train_batch_idx = generate_batch_splits(train_samples_idx, batch_size_per_update)
1501
+
1502
+ # Gather the indices for creating the batch and do a training step
1503
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1), 1):
1504
+ samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
1505
+ batch = data_collator(samples)
1506
+ batch = shard(batch.data)
1507
+ try:
1508
+ state, train_metric = p_train_step(state, batch)
1509
+ except TypeError as e:
1510
+ logger.warning("Encountered following error: \n", e)
1511
+
1512
+ cur_step = epoch * (num_train_samples // batch_size_per_update) + step
1513
+
1514
+ if cur_step % training_args.logging_steps == 0:
1515
+ # Save metrics
1516
+ train_metric = unreplicate(train_metric)
1517
+ train_time += time.time() - train_start
1518
+ # need to upcast all device arrays to fp32 for wandb logging (jnp.bfloat16 not supported) -> do this here OR in train_step
1519
+ write_wandb_log(to_fp32(train_metric), cur_step, prefix="train")
1520
+ # we won't log to tensorboard for now (it is fiddly logging param and grad norms on a layer-by-layer basis)
1521
+ # if has_tensorboard and jax.process_index() == 0:
1522
+ # write_train_metric(summary_writer, train_metrics, train_time, cur_step)
1523
+
1524
+ epochs.write(
1525
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']}, Gradient Norm: {train_metric['grad_norm']})"
1526
+ )
1527
+
1528
+ if cur_step % total_train_steps == 0:
1529
+ break
1530
+
1531
+ if training_args.eval_steps and cur_step % training_args.eval_steps == 0:
1532
+ run_evaluation(cur_step)
1533
+
1534
+ if cur_step % training_args.save_steps == 0:
1535
+ save_checkpoint(cur_step)
1536
+
1537
+ if training_args.eval_steps == 0 and (epoch + 1) != num_epochs:
1538
+ # run evaluation at the end of the epoch if eval steps are not specified
1539
+ run_evaluation(cur_step)
1540
+ save_checkpoint(cur_step)
1541
+
1542
+ if training_args.do_train:
1543
+ save_checkpoint(cur_step)
1544
+
1545
+ cur_step = max_steps if max_steps > 0 else cur_step # set step to max steps so that eval happens in alignment with training
1546
+
1547
+ if training_args.do_eval:
1548
+ run_evaluation(cur_step)
1549
+
1550
+ # TODO: collapse 'do_predict' into the run_evaluation function
1551
+ if training_args.do_predict:
1552
+ for split in test_split:
1553
+ # ======================== Evaluating ==============================
1554
+ eval_metrics = []
1555
+ eval_preds = []
1556
+ eval_labels = []
1557
+
1558
+ # Generate eval set by sequentially sampling indices from the test dataset and grouping by length
1559
+ eval_samples_idx = get_grouped_indices(vectorized_datasets[split], eval_batch_size)
1560
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
1561
+
1562
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc=f"Predicting {split}...", position=2)):
1563
+ samples = [vectorized_datasets[split][int(idx)] for idx in batch_idx]
1564
+ batch = data_collator(samples)
1565
+ labels = batch["labels"]
1566
+
1567
+ metrics, pred_ids = pad_shard_unpad(p_eval_step)(state.params, batch.data, min_device_batch=per_device_eval_batch_size)
1568
+ eval_preds.extend(jax.device_get(pred_ids.reshape(-1, pred_ids.shape[-1])))
1569
+ eval_metrics.append(metrics)
1570
+
1571
+ eval_labels.extend(labels)
1572
+
1573
+ # normalize eval metrics
1574
+ eval_metrics = get_metrics(eval_metrics)
1575
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
1576
+ eval_metrics = to_fp32(eval_metrics)
1577
+
1578
+ # always run compute metrics
1579
+ error_rate_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
1580
+ eval_metrics.update(error_rate_metric)
1581
+ error_rate_desc = " ".join([f"Eval {key}: {value} |" for key, value in error_rate_metric.items()])
1582
+
1583
+ # Print metrics and update progress bar
1584
+ desc = f"Step... ({cur_step}/{total_train_steps} | Eval Loss: {eval_metrics['loss']} | {error_rate_desc})"
1585
+ epochs.write(desc)
1586
+ epochs.desc = desc
1587
+
1588
+ # Save metrics
1589
+ write_wandb_log(eval_metrics, cur_step, prefix=split)
1590
+ write_wandb_pred(pred_str, label_str, cur_step, prefix=split)
1591
+ # if has_tensorboard and jax.process_index() == 0:
1592
+ # write_eval_metric(summary_writer, eval_metrics, cur_step, pred_str=pred_str)
1593
+
1594
+
1595
+ if __name__ == "__main__":
1596
+ main()
wandb/run-20220729_184558-17ksemgv/files/config.yaml ADDED
@@ -0,0 +1,33 @@
1
+ wandb_version: 1
2
+
3
+ _wandb:
4
+ desc: null
5
+ value:
6
+ cli_version: 0.12.9
7
+ code_path: code/run_flax_speech_recognition_ctc.py
8
+ framework: huggingface
9
+ huggingface_version: 4.21.0
10
+ is_jupyter_run: false
11
+ is_kaggle_kernel: false
12
+ python_version: 3.8.10
13
+ start_time: 1659120358
14
+ t:
15
+ 1:
16
+ - 1
17
+ - 2
18
+ - 3
19
+ - 11
20
+ - 12
21
+ 2:
22
+ - 1
23
+ - 2
24
+ - 3
25
+ - 11
26
+ - 12
27
+ 3:
28
+ - 13
29
+ 4: 3.8.10
30
+ 5: 0.12.9
31
+ 6: 4.21.0
32
+ 8:
33
+ - 5