ct-vikramanantha commited on
Commit
ecfe08e
·
verified ·
1 Parent(s): a712c49

Uploaded and Created

Browse files
README.md ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ license: apache-2.0
4
+ tags:
5
+ - phoneme-recognition
6
+ - generated_from_trainer
7
+ datasets:
8
+ - w11wo/ljspeech_phonemes
9
+ base_model: Wav2Vec2-Base
10
+ inference:
11
+ parameters:
12
+ function_to_apply: none
13
+ model-index:
14
+ - name: Wav2Vec2 LJSpeech Gruut
15
+ results:
16
+ - task:
17
+ type: automatic-speech-recognition
18
+ name: Automatic Speech Recognition
19
+ dataset:
20
+ name: LJSpeech
21
+ type: ljspeech_phonemes
22
+ metrics:
23
+ - type: per
24
+ value: 0.0099
25
+ name: Test PER (w/o stress)
26
+ - type: cer
27
+ value: 0.0058
28
+ name: Test CER (w/o stress)
29
+ ---
30
+
31
+ # Wav2Vec2 LJSpeech Gruut
32
+
33
+ Clone of wav2vec2-ljspeech-gruut because I want to use pipeline and get the logits from it
34
+
35
+ ## Model
36
+
37
+ | Model | #params | Arch. | Training/Validation data (text) |
38
+ | ------------------------- | ------- | ----------- | ------------------------------- |
39
+ | `wav2vec2-ljspeech-gruut` | 94M | wav2vec 2.0 | `LJSpech Phonemes` Dataset |
40
+
41
+ ## Evaluation Results
42
+
43
+ The model achieves the following results on evaluation:
44
+
45
+ | Dataset | PER (w/o stress) | CER (w/o stress) |
46
+ | ---------------------------- | :--------------: | :--------------: |
47
+ | `LJSpech Phonemes` Test Data | 0.99% | 0.58% |
48
+
49
+ ## Usage
50
+
51
+ ```py
52
+ from transformers import AutoProcessor, AutoModelForCTC, Wav2Vec2Processor
53
+ import librosa
54
+ import torch
55
+ from itertools import groupby
56
+ from datasets import load_dataset
57
+
58
+ def decode_phonemes(
59
+ ids: torch.Tensor, processor: Wav2Vec2Processor, ignore_stress: bool = False
60
+ ) -> str:
61
+ """CTC-like decoding. First removes consecutive duplicates, then removes special tokens."""
62
+ # removes consecutive duplicates
63
+ ids = [id_ for id_, _ in groupby(ids)]
64
+
65
+ special_token_ids = processor.tokenizer.all_special_ids + [
66
+ processor.tokenizer.word_delimiter_token_id
67
+ ]
68
+ # converts id to token, skipping special tokens
69
+ phonemes = [processor.decode(id_) for id_ in ids if id_ not in special_token_ids]
70
+
71
+ # joins phonemes
72
+ prediction = " ".join(phonemes)
73
+
74
+ # whether to ignore IPA stress marks
75
+ if ignore_stress == True:
76
+ prediction = prediction.replace("ˈ", "").replace("ˌ", "")
77
+
78
+ return prediction
79
+
80
+ checkpoint = "bookbot/wav2vec2-ljspeech-gruut"
81
+
82
+ model = AutoModelForCTC.from_pretrained(checkpoint)
83
+ processor = AutoProcessor.from_pretrained(checkpoint)
84
+ sr = processor.feature_extractor.sampling_rate
85
+
86
+ # load dummy dataset and read soundfiles
87
+ ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
88
+ audio_array = ds[0]["audio"]["array"]
89
+
90
+ # or, read a single audio file
91
+ # audio_array, _ = librosa.load("myaudio.wav", sr=sr)
92
+
93
+ inputs = processor(audio_array, return_tensors="pt", padding=True)
94
+
95
+ with torch.no_grad():
96
+ logits = model(inputs["input_values"]).logits
97
+
98
+ predicted_ids = torch.argmax(logits, dim=-1)
99
+ prediction = decode_phonemes(predicted_ids[0], processor, ignore_stress=True)
100
+ # => should give 'b ɪ k ʌ z j u ɚ z s l i p ɪ ŋ ɪ n s t ɛ d ə v k ɔ ŋ k ɚ ɪ ŋ ð ə l ʌ v l i ɹ z p ɹ ɪ n s ə s h æ z b ɪ k ʌ m ə v f ɪ t ə l w ɪ θ n b oʊ p ɹ ə ʃ æ ɡ i s ɪ t s ð ɛ ɹ ə k u ɪ ŋ d ʌ v'
101
+ ```
102
+
103
+ ## Training procedure
104
+
105
+ ### Training hyperparameters
106
+
107
+ The following hyperparameters were used during training:
108
+
109
+ - `learning_rate`: 0.0001
110
+ - `train_batch_size`: 16
111
+ - `eval_batch_size`: 8
112
+ - `seed`: 42
113
+ - `gradient_accumulation_steps`: 2
114
+ - `total_train_batch_size`: 32
115
+ - `optimizer`: Adam with `betas=(0.9,0.999)` and `epsilon=1e-08`
116
+ - `lr_scheduler_type`: linear
117
+ - `lr_scheduler_warmup_steps`: 1000
118
+ - `num_epochs`: 30.0
119
+ - `mixed_precision_training`: Native AMP
120
+
121
+ ### Training results
122
+
123
+ | Training Loss | Epoch | Step | Validation Loss | Wer | Cer |
124
+ | :-----------: | :---: | :---: | :-------------: | :----: | :----: |
125
+ | No log | 1.0 | 348 | 2.2818 | 1.0 | 1.0 |
126
+ | 2.6692 | 2.0 | 696 | 0.2045 | 0.0527 | 0.0299 |
127
+ | 0.2225 | 3.0 | 1044 | 0.1162 | 0.0319 | 0.0189 |
128
+ | 0.2225 | 4.0 | 1392 | 0.0927 | 0.0235 | 0.0147 |
129
+ | 0.0868 | 5.0 | 1740 | 0.0797 | 0.0218 | 0.0143 |
130
+ | 0.0598 | 6.0 | 2088 | 0.0715 | 0.0197 | 0.0128 |
131
+ | 0.0598 | 7.0 | 2436 | 0.0652 | 0.0160 | 0.0103 |
132
+ | 0.0447 | 8.0 | 2784 | 0.0571 | 0.0152 | 0.0095 |
133
+ | 0.0368 | 9.0 | 3132 | 0.0608 | 0.0163 | 0.0112 |
134
+ | 0.0368 | 10.0 | 3480 | 0.0586 | 0.0137 | 0.0083 |
135
+ | 0.0303 | 11.0 | 3828 | 0.0641 | 0.0141 | 0.0085 |
136
+ | 0.0273 | 12.0 | 4176 | 0.0656 | 0.0131 | 0.0079 |
137
+ | 0.0232 | 13.0 | 4524 | 0.0690 | 0.0133 | 0.0082 |
138
+ | 0.0232 | 14.0 | 4872 | 0.0598 | 0.0128 | 0.0079 |
139
+ | 0.0189 | 15.0 | 5220 | 0.0671 | 0.0121 | 0.0074 |
140
+ | 0.017 | 16.0 | 5568 | 0.0654 | 0.0114 | 0.0069 |
141
+ | 0.017 | 17.0 | 5916 | 0.0751 | 0.0118 | 0.0073 |
142
+ | 0.0146 | 18.0 | 6264 | 0.0653 | 0.0112 | 0.0068 |
143
+ | 0.0127 | 19.0 | 6612 | 0.0682 | 0.0112 | 0.0069 |
144
+ | 0.0127 | 20.0 | 6960 | 0.0678 | 0.0114 | 0.0068 |
145
+ | 0.0114 | 21.0 | 7308 | 0.0656 | 0.0111 | 0.0066 |
146
+ | 0.0101 | 22.0 | 7656 | 0.0669 | 0.0109 | 0.0066 |
147
+ | 0.0092 | 23.0 | 8004 | 0.0677 | 0.0108 | 0.0065 |
148
+ | 0.0092 | 24.0 | 8352 | 0.0653 | 0.0104 | 0.0063 |
149
+ | 0.0088 | 25.0 | 8700 | 0.0673 | 0.0102 | 0.0063 |
150
+ | 0.0074 | 26.0 | 9048 | 0.0669 | 0.0105 | 0.0064 |
151
+ | 0.0074 | 27.0 | 9396 | 0.0707 | 0.0101 | 0.0061 |
152
+ | 0.0066 | 28.0 | 9744 | 0.0673 | 0.0100 | 0.0060 |
153
+ | 0.0058 | 29.0 | 10092 | 0.0689 | 0.0100 | 0.0059 |
154
+ | 0.0058 | 30.0 | 10440 | 0.0683 | 0.0099 | 0.0058 |
155
+
156
+
157
+ ## Disclaimer
158
+
159
+ Do consider the biases which came from pre-training datasets that may be carried over into the results of this model.
160
+
161
+ ## Authors
162
+
163
+ Wav2Vec2 LJSpeech Gruut was trained and evaluated by [Wilson Wongso](https://w11wo.github.io/). All computation and development are done on Google Cloud.
164
+
165
+ ## Framework versions
166
+
167
+ - Transformers 4.26.0.dev0
168
+ - Pytorch 1.10.0
169
+ - Datasets 2.7.1
170
+ - Tokenizers 0.13.2
171
+ - Gruut 2.3.4
added_tokens.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "</s>": 44,
3
+ "<s>": 43
4
+ }
all_results.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 30.0,
3
+ "eval_cer": 0.0058143100249896515,
4
+ "eval_loss": 0.06832413375377655,
5
+ "eval_runtime": 87.0708,
6
+ "eval_samples": 1965,
7
+ "eval_samples_per_second": 22.568,
8
+ "eval_steps_per_second": 2.825,
9
+ "eval_wer": 0.009874807524938073,
10
+ "train_loss": 0.15939651108792915,
11
+ "train_runtime": 19958.548,
12
+ "train_samples": 11135,
13
+ "train_samples_per_second": 16.737,
14
+ "train_steps_per_second": 0.523
15
+ }
config.json ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "facebook/wav2vec2-base",
3
+ "activation_dropout": 0.0,
4
+ "adapter_kernel_size": 3,
5
+ "adapter_stride": 2,
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ "Wav2Vec2ForCTC"
10
+ ],
11
+ "attention_dropout": 0.0,
12
+ "bos_token_id": 1,
13
+ "classifier_proj_size": 256,
14
+ "codevector_dim": 256,
15
+ "contrastive_logits_temperature": 0.1,
16
+ "conv_bias": false,
17
+ "conv_dim": [
18
+ 512,
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512
25
+ ],
26
+ "conv_kernel": [
27
+ 10,
28
+ 3,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 2,
33
+ 2
34
+ ],
35
+ "conv_stride": [
36
+ 5,
37
+ 2,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2
43
+ ],
44
+ "ctc_loss_reduction": "mean",
45
+ "ctc_zero_infinity": false,
46
+ "diversity_loss_weight": 0.1,
47
+ "do_stable_layer_norm": false,
48
+ "eos_token_id": 2,
49
+ "feat_extract_activation": "gelu",
50
+ "feat_extract_norm": "group",
51
+ "feat_proj_dropout": 0.0,
52
+ "feat_quantizer_dropout": 0.0,
53
+ "final_dropout": 0.0,
54
+ "freeze_feat_extract_train": true,
55
+ "hidden_act": "gelu",
56
+ "hidden_dropout": 0.0,
57
+ "hidden_size": 768,
58
+ "initializer_range": 0.02,
59
+ "intermediate_size": 3072,
60
+ "layer_norm_eps": 1e-05,
61
+ "layerdrop": 0.0,
62
+ "mask_channel_length": 10,
63
+ "mask_channel_min_space": 1,
64
+ "mask_channel_other": 0.0,
65
+ "mask_channel_prob": 0.0,
66
+ "mask_channel_selection": "static",
67
+ "mask_feature_length": 10,
68
+ "mask_feature_min_masks": 0,
69
+ "mask_feature_prob": 0.0,
70
+ "mask_time_length": 10,
71
+ "mask_time_min_masks": 2,
72
+ "mask_time_min_space": 1,
73
+ "mask_time_other": 0.0,
74
+ "mask_time_prob": 0.05,
75
+ "mask_time_selection": "static",
76
+ "model_type": "wav2vec2",
77
+ "no_mask_channel_overlap": false,
78
+ "no_mask_time_overlap": false,
79
+ "num_adapter_layers": 3,
80
+ "num_attention_heads": 12,
81
+ "num_codevector_groups": 2,
82
+ "num_codevectors_per_group": 320,
83
+ "num_conv_pos_embedding_groups": 16,
84
+ "num_conv_pos_embeddings": 128,
85
+ "num_feat_extract_layers": 7,
86
+ "num_hidden_layers": 12,
87
+ "num_negatives": 100,
88
+ "output_hidden_size": 768,
89
+ "pad_token_id": 42,
90
+ "proj_codevector_dim": 256,
91
+ "tdnn_dilation": [
92
+ 1,
93
+ 2,
94
+ 3,
95
+ 1,
96
+ 1
97
+ ],
98
+ "tdnn_dim": [
99
+ 512,
100
+ 512,
101
+ 512,
102
+ 512,
103
+ 1500
104
+ ],
105
+ "tdnn_kernel": [
106
+ 5,
107
+ 3,
108
+ 3,
109
+ 1,
110
+ 1
111
+ ],
112
+ "torch_dtype": "float32",
113
+ "transformers_version": "4.26.0.dev0",
114
+ "use_weighted_layer_sum": false,
115
+ "vocab_size": 45,
116
+ "xvector_output_dim": 512
117
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36094944bb26c933f05f2d3e40e14ae495479700380597ce11e1546c6df08c93
3
+ size 134
preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0.0,
7
+ "return_attention_mask": false,
8
+ "sampling_rate": 16000
9
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70b3f3f15a1a1f2ca80f9af5764e9b9098a0a7b7c5a6a43494edd2fcd9503dec
3
+ size 134
run.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python run_speech_recognition_ctc.py \
2
+ --dataset_name="w11wo/ljspeech_phonemes" \
3
+ --text_column_name="phonemes" \
4
+ --train_split_name="train" \
5
+ --model_name_or_path="facebook/wav2vec2-base" \
6
+ --output_dir="./wav2vec2-ljspeech-gruut" \
7
+ --overwrite_output_dir \
8
+ --num_train_epochs="30" \
9
+ --per_device_train_batch_size="16" \
10
+ --gradient_accumulation_steps="2" \
11
+ --learning_rate="1e-4" \
12
+ --warmup_steps="1000" \
13
+ --weight_decay="0.005" \
14
+ --evaluation_strategy="epoch" \
15
+ --eval_metrics wer cer \
16
+ --save_strategy="epoch" \
17
+ --layerdrop="0.0" \
18
+ --save_total_limit="3" \
19
+ --freeze_feature_encoder \
20
+ --gradient_checkpointing \
21
+ --chars_to_ignore , ? . ! - \; \: \" “ % ‘ ” � ˈ ˌ \
22
+ --fp16 \
23
+ --group_by_length \
24
+ --report_to="tensorboard" \
25
+ --push_to_hub \
26
+ --do_train --do_eval \
27
+ --hub_model_id="bookbot/wav2vec2-ljspeech-gruut" \
28
+ --hub_private_repo="True" \
29
+ --use_auth_token="True"
run_speech_recognition_ctc.py ADDED
@@ -0,0 +1,856 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ """ Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition"""
17
+
18
+ import functools
19
+ import json
20
+ import logging
21
+ import os
22
+ import re
23
+ import sys
24
+ import warnings
25
+ from dataclasses import dataclass, field
26
+ from typing import Dict, List, Optional, Union
27
+
28
+ import datasets
29
+ import numpy as np
30
+ import torch
31
+ from datasets import DatasetDict, load_dataset
32
+
33
+ import evaluate
34
+ import transformers
35
+ from transformers import (
36
+ AutoConfig,
37
+ AutoFeatureExtractor,
38
+ AutoModelForCTC,
39
+ AutoProcessor,
40
+ AutoTokenizer,
41
+ HfArgumentParser,
42
+ Trainer,
43
+ TrainingArguments,
44
+ Wav2Vec2Processor,
45
+ set_seed,
46
+ )
47
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
48
+ from transformers.utils import check_min_version, send_example_telemetry
49
+ from transformers.utils.versions import require_version
50
+
51
+
52
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
53
+ check_min_version("4.26.0.dev0")
54
+
55
+ require_version(
56
+ "datasets>=1.18.0",
57
+ "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt",
58
+ )
59
+
60
+
61
+ logger = logging.getLogger(__name__)
62
+
63
+
64
+ def list_field(default=None, metadata=None):
65
+ return field(default_factory=lambda: default, metadata=metadata)
66
+
67
+
68
+ @dataclass
69
+ class ModelArguments:
70
+ """
71
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
72
+ """
73
+
74
+ model_name_or_path: str = field(
75
+ metadata={
76
+ "help": "Path to pretrained model or model identifier from huggingface.co/models"
77
+ }
78
+ )
79
+ tokenizer_name_or_path: Optional[str] = field(
80
+ default=None,
81
+ metadata={
82
+ "help": "Path to pretrained tokenizer or tokenizer identifier from huggingface.co/models"
83
+ },
84
+ )
85
+ cache_dir: Optional[str] = field(
86
+ default=None,
87
+ metadata={
88
+ "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
89
+ },
90
+ )
91
+ freeze_feature_encoder: bool = field(
92
+ default=True,
93
+ metadata={"help": "Whether to freeze the feature encoder layers of the model."},
94
+ )
95
+ attention_dropout: float = field(
96
+ default=0.0,
97
+ metadata={"help": "The dropout ratio for the attention probabilities."},
98
+ )
99
+ activation_dropout: float = field(
100
+ default=0.0,
101
+ metadata={
102
+ "help": "The dropout ratio for activations inside the fully connected layer."
103
+ },
104
+ )
105
+ feat_proj_dropout: float = field(
106
+ default=0.0, metadata={"help": "The dropout ratio for the projected features."}
107
+ )
108
+ hidden_dropout: float = field(
109
+ default=0.0,
110
+ metadata={
111
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
112
+ },
113
+ )
114
+ final_dropout: float = field(
115
+ default=0.0,
116
+ metadata={"help": "The dropout probability for the final projection layer."},
117
+ )
118
+ mask_time_prob: float = field(
119
+ default=0.05,
120
+ metadata={
121
+ "help": (
122
+ "Probability of each feature vector along the time axis to be chosen as the start of the vector"
123
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
124
+ "vectors will be masked along the time axis."
125
+ )
126
+ },
127
+ )
128
+ mask_time_length: int = field(
129
+ default=10,
130
+ metadata={"help": "Length of vector span to mask along the time axis."},
131
+ )
132
+ mask_feature_prob: float = field(
133
+ default=0.0,
134
+ metadata={
135
+ "help": (
136
+ "Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
137
+ " to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
138
+ " bins will be masked along the time axis."
139
+ )
140
+ },
141
+ )
142
+ mask_feature_length: int = field(
143
+ default=10,
144
+ metadata={"help": "Length of vector span to mask along the feature axis."},
145
+ )
146
+ layerdrop: float = field(
147
+ default=0.0, metadata={"help": "The LayerDrop probability."}
148
+ )
149
+ ctc_loss_reduction: Optional[str] = field(
150
+ default="mean",
151
+ metadata={
152
+ "help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."
153
+ },
154
+ )
155
+
156
+
157
+ @dataclass
158
+ class DataTrainingArguments:
159
+ """
160
+ Arguments pertaining to what data we are going to input our model for training and eval.
161
+
162
+ Using `HfArgumentParser` we can turn this class
163
+ into argparse arguments to be able to specify them on
164
+ the command line.
165
+ """
166
+
167
+ dataset_name: str = field(
168
+ metadata={
169
+ "help": "The configuration name of the dataset to use (via the datasets library)."
170
+ }
171
+ )
172
+ dataset_config_name: str = field(
173
+ default=None,
174
+ metadata={
175
+ "help": "The configuration name of the dataset to use (via the datasets library)."
176
+ },
177
+ )
178
+ train_split_name: str = field(
179
+ default="train+validation",
180
+ metadata={
181
+ "help": (
182
+ "The name of the training data set split to use (via the datasets library). Defaults to "
183
+ "'train+validation'"
184
+ )
185
+ },
186
+ )
187
+ eval_split_name: str = field(
188
+ default="test",
189
+ metadata={
190
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
191
+ },
192
+ )
193
+ audio_column_name: str = field(
194
+ default="audio",
195
+ metadata={
196
+ "help": "The name of the dataset column containing the audio data. Defaults to 'audio'"
197
+ },
198
+ )
199
+ text_column_name: str = field(
200
+ default="text",
201
+ metadata={
202
+ "help": "The name of the dataset column containing the text data. Defaults to 'text'"
203
+ },
204
+ )
205
+ overwrite_cache: bool = field(
206
+ default=False,
207
+ metadata={"help": "Overwrite the cached preprocessed datasets or not."},
208
+ )
209
+ preprocessing_num_workers: Optional[int] = field(
210
+ default=None,
211
+ metadata={"help": "The number of processes to use for the preprocessing."},
212
+ )
213
+ max_train_samples: Optional[int] = field(
214
+ default=None,
215
+ metadata={
216
+ "help": (
217
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
218
+ "value if set."
219
+ )
220
+ },
221
+ )
222
+ max_eval_samples: Optional[int] = field(
223
+ default=None,
224
+ metadata={
225
+ "help": (
226
+ "For debugging purposes or quicker training, truncate the number of validation examples to this "
227
+ "value if set."
228
+ )
229
+ },
230
+ )
231
+ chars_to_ignore: Optional[List[str]] = list_field(
232
+ default=None,
233
+ metadata={"help": "A list of characters to remove from the transcripts."},
234
+ )
235
+ eval_metrics: List[str] = list_field(
236
+ default=["wer"],
237
+ metadata={
238
+ "help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"
239
+ },
240
+ )
241
+ max_duration_in_seconds: float = field(
242
+ default=20.0,
243
+ metadata={
244
+ "help": (
245
+ "Filter audio files that are longer than `max_duration_in_seconds` seconds to"
246
+ " 'max_duration_in_seconds`"
247
+ )
248
+ },
249
+ )
250
+ min_duration_in_seconds: float = field(
251
+ default=0.0,
252
+ metadata={
253
+ "help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"
254
+ },
255
+ )
256
+ preprocessing_only: bool = field(
257
+ default=False,
258
+ metadata={
259
+ "help": (
260
+ "Whether to only do data preprocessing and skip training. This is especially useful when data"
261
+ " preprocessing errors out in distributed training due to timeout. In this case, one should run the"
262
+ " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets"
263
+ " can consequently be loaded in distributed training"
264
+ )
265
+ },
266
+ )
267
+ use_auth_token: bool = field(
268
+ default=False,
269
+ metadata={
270
+ "help": (
271
+ "If :obj:`True`, will use the token generated when running"
272
+ ":obj:`huggingface-cli login` as HTTP bearer authorization for remote files."
273
+ )
274
+ },
275
+ )
276
+ unk_token: str = field(
277
+ default="[UNK]",
278
+ metadata={"help": "The unk token for the tokenizer"},
279
+ )
280
+ pad_token: str = field(
281
+ default="[PAD]",
282
+ metadata={"help": "The padding token for the tokenizer"},
283
+ )
284
+ word_delimiter_token: str = field(
285
+ default="|",
286
+ metadata={"help": "The word delimiter token for the tokenizer"},
287
+ )
288
+ phoneme_language: Optional[str] = field(
289
+ default=None,
290
+ metadata={
291
+ "help": (
292
+ "The target language that should be used be"
293
+ " passed to the tokenizer for tokenization. Note that"
294
+ " this is only relevant if the model classifies the"
295
+ " input audio to a sequence of phoneme sequences."
296
+ )
297
+ },
298
+ )
299
+
300
+
301
+ @dataclass
302
+ class DataCollatorCTCWithPadding:
303
+ """
304
+ Data collator that will dynamically pad the inputs received.
305
+ Args:
306
+ processor (:class:`~transformers.AutoProcessor`)
307
+ The processor used for proccessing the data.
308
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
309
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
310
+ among:
311
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
312
+ sequence if provided).
313
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
314
+ maximum acceptable input length for the model if that argument is not provided.
315
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
316
+ different lengths).
317
+ max_length (:obj:`int`, `optional`):
318
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
319
+ max_length_labels (:obj:`int`, `optional`):
320
+ Maximum length of the ``labels`` returned list and optionally padding length (see above).
321
+ pad_to_multiple_of (:obj:`int`, `optional`):
322
+ If set will pad the sequence to a multiple of the provided value.
323
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
324
+ 7.5 (Volta).
325
+ """
326
+
327
+ processor: AutoProcessor
328
+ padding: Union[bool, str] = "longest"
329
+ pad_to_multiple_of: Optional[int] = None
330
+ pad_to_multiple_of_labels: Optional[int] = None
331
+
332
+ def __call__(
333
+ self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
334
+ ) -> Dict[str, torch.Tensor]:
335
+ # split inputs and labels since they have to be of different lenghts and need
336
+ # different padding methods
337
+ input_features = [
338
+ {"input_values": feature["input_values"]} for feature in features
339
+ ]
340
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
341
+
342
+ batch = self.processor.pad(
343
+ input_features,
344
+ padding=self.padding,
345
+ pad_to_multiple_of=self.pad_to_multiple_of,
346
+ return_tensors="pt",
347
+ )
348
+
349
+ labels_batch = self.processor.pad(
350
+ labels=label_features,
351
+ padding=self.padding,
352
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
353
+ return_tensors="pt",
354
+ )
355
+
356
+ # replace padding with -100 to ignore loss correctly
357
+ labels = labels_batch["input_ids"].masked_fill(
358
+ labels_batch.attention_mask.ne(1), -100
359
+ )
360
+
361
+ batch["labels"] = labels
362
+ if "attention_mask" in batch:
363
+ batch["attention_mask"] = batch["attention_mask"].to(torch.long)
364
+
365
+ return batch
366
+
367
+
368
+ def create_vocabulary_from_data(
369
+ datasets: DatasetDict,
370
+ word_delimiter_token: Optional[str] = None,
371
+ unk_token: Optional[str] = None,
372
+ pad_token: Optional[str] = None,
373
+ ):
374
+ # Given training and test labels create vocabulary
375
+ def extract_all_chars(batch):
376
+ all_text = " ".join(batch["target_text"])
377
+ # phonemes are split by whitespace
378
+ vocab = list(set(all_text.split())) + [" "]
379
+ return {"vocab": [vocab], "all_text": [all_text]}
380
+
381
+ vocabs = datasets.map(
382
+ extract_all_chars,
383
+ batched=True,
384
+ batch_size=-1,
385
+ keep_in_memory=True,
386
+ remove_columns=datasets["train"].column_names,
387
+ )
388
+
389
+ # take union of all unique characters in each dataset
390
+ vocab_set = functools.reduce(
391
+ lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]),
392
+ vocabs.values(),
393
+ )
394
+
395
+ vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
396
+
397
+ # replace white space with delimiter token
398
+ if word_delimiter_token is not None:
399
+ vocab_dict[word_delimiter_token] = vocab_dict[" "]
400
+ del vocab_dict[" "]
401
+
402
+ # add unk and pad token
403
+ if unk_token is not None:
404
+ vocab_dict[unk_token] = len(vocab_dict)
405
+
406
+ if pad_token is not None:
407
+ vocab_dict[pad_token] = len(vocab_dict)
408
+
409
+ return vocab_dict
410
+
411
+
412
+ def main():
413
+ # See all possible arguments in src/transformers/training_args.py
414
+ # or by passing the --help flag to this script.
415
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
416
+
417
+ parser = HfArgumentParser(
418
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
419
+ )
420
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
421
+ # If we pass only one argument to the script and it's the path to a json file,
422
+ # let's parse it to get our arguments.
423
+ model_args, data_args, training_args = parser.parse_json_file(
424
+ json_file=os.path.abspath(sys.argv[1])
425
+ )
426
+ else:
427
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
428
+
429
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
430
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
431
+ send_example_telemetry("run_speech_recognition_ctc", model_args, data_args)
432
+
433
+ # Detecting last checkpoint.
434
+ last_checkpoint = None
435
+ if (
436
+ os.path.isdir(training_args.output_dir)
437
+ and training_args.do_train
438
+ and not training_args.overwrite_output_dir
439
+ ):
440
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
441
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
442
+ raise ValueError(
443
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
444
+ "Use --overwrite_output_dir to overcome."
445
+ )
446
+ elif last_checkpoint is not None:
447
+ logger.info(
448
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
449
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
450
+ )
451
+
452
+ # Setup logging
453
+ logging.basicConfig(
454
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
455
+ datefmt="%m/%d/%Y %H:%M:%S",
456
+ handlers=[logging.StreamHandler(sys.stdout)],
457
+ )
458
+ logger.setLevel(
459
+ logging.INFO if is_main_process(training_args.local_rank) else logging.WARN
460
+ )
461
+
462
+ # Log on each process the small summary:
463
+ logger.warning(
464
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
465
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
466
+ )
467
+ # Set the verbosity to info of the Transformers logger (on main process only):
468
+ if is_main_process(training_args.local_rank):
469
+ transformers.utils.logging.set_verbosity_info()
470
+ logger.info("Training/evaluation parameters %s", training_args)
471
+
472
+ # Set seed before initializing model.
473
+ set_seed(training_args.seed)
474
+
475
+ # 1. First, let's load the dataset
476
+ raw_datasets = load_dataset(
477
+ data_args.dataset_name,
478
+ data_args.dataset_config_name,
479
+ split=data_args.train_split_name,
480
+ )
481
+
482
+ raw_datasets = raw_datasets.train_test_split(test_size=0.15)
483
+
484
+ if training_args.do_train:
485
+ if data_args.audio_column_name not in raw_datasets["train"].column_names:
486
+ raise ValueError(
487
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'."
488
+ " Make sure to set `--audio_column_name` to the correct audio column - one of"
489
+ f" {', '.join(raw_datasets['train'].column_names)}."
490
+ )
491
+
492
+ if data_args.text_column_name not in raw_datasets["train"].column_names:
493
+ raise ValueError(
494
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
495
+ "Make sure to set `--text_column_name` to the correct text column - one of "
496
+ f"{', '.join(raw_datasets['train'].column_names)}."
497
+ )
498
+
499
+ if data_args.max_train_samples is not None:
500
+ raw_datasets["train"] = raw_datasets["train"].select(
501
+ range(data_args.max_train_samples)
502
+ )
503
+
504
+ if training_args.do_eval:
505
+ if data_args.max_eval_samples is not None:
506
+ raw_datasets["test"] = raw_datasets["test"].select(
507
+ range(data_args.max_eval_samples)
508
+ )
509
+
510
+ # 2. We remove some special characters from the datasets
511
+ # that make training complicated and do not help in transcribing the speech
512
+ # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
513
+ # that could be easily picked up by the model
514
+ chars_to_ignore_regex = (
515
+ f'[{"".join(data_args.chars_to_ignore)}]'
516
+ if data_args.chars_to_ignore is not None
517
+ else None
518
+ )
519
+ text_column_name = data_args.text_column_name
520
+
521
+ def remove_special_characters(batch):
522
+ if chars_to_ignore_regex is not None:
523
+ batch["target_text"] = (
524
+ re.sub(chars_to_ignore_regex, "", batch[text_column_name]).lower() + " "
525
+ )
526
+ else:
527
+ batch["target_text"] = batch[text_column_name].lower() + " "
528
+ return batch
529
+
530
+ with training_args.main_process_first(
531
+ desc="dataset map special characters removal"
532
+ ):
533
+ raw_datasets = raw_datasets.map(
534
+ remove_special_characters,
535
+ remove_columns=[text_column_name],
536
+ desc="remove special characters from datasets",
537
+ )
538
+
539
+ # save special tokens for tokenizer
540
+ word_delimiter_token = data_args.word_delimiter_token
541
+ unk_token = data_args.unk_token
542
+ pad_token = data_args.pad_token
543
+
544
+ # 3. Next, let's load the config as we might need it to create
545
+ # the tokenizer
546
+ # load config
547
+ config = AutoConfig.from_pretrained(
548
+ model_args.model_name_or_path,
549
+ cache_dir=model_args.cache_dir,
550
+ )
551
+
552
+ # 4. Next, if no tokenizer file is defined,
553
+ # we create the vocabulary of the model by extracting all unique characters from
554
+ # the training and evaluation datasets
555
+ # We need to make sure that only first rank saves vocabulary
556
+ # make sure all processes wait until vocab is created
557
+ tokenizer_name_or_path = model_args.tokenizer_name_or_path
558
+ tokenizer_kwargs = {}
559
+ if tokenizer_name_or_path is None:
560
+ # save vocab in training output dir
561
+ tokenizer_name_or_path = training_args.output_dir
562
+
563
+ vocab_file = os.path.join(tokenizer_name_or_path, "vocab.json")
564
+
565
+ with training_args.main_process_first():
566
+ if training_args.overwrite_output_dir and os.path.isfile(vocab_file):
567
+ try:
568
+ os.remove(vocab_file)
569
+ except OSError:
570
+ # in shared file-systems it might be the case that
571
+ # two processes try to delete the vocab file at the some time
572
+ pass
573
+
574
+ with training_args.main_process_first(desc="dataset map vocabulary creation"):
575
+ if not os.path.isfile(vocab_file):
576
+ os.makedirs(tokenizer_name_or_path, exist_ok=True)
577
+ vocab_dict = create_vocabulary_from_data(
578
+ raw_datasets,
579
+ word_delimiter_token=word_delimiter_token,
580
+ unk_token=unk_token,
581
+ pad_token=pad_token,
582
+ )
583
+
584
+ # save vocab dict to be loaded into tokenizer
585
+ with open(vocab_file, "w") as file:
586
+ json.dump(vocab_dict, file)
587
+
588
+ # if tokenizer has just been created
589
+ # it is defined by `tokenizer_class` if present in config else by `model_type`
590
+ tokenizer_kwargs = {
591
+ "config": config if config.tokenizer_class is not None else None,
592
+ "tokenizer_type": config.model_type
593
+ if config.tokenizer_class is None
594
+ else None,
595
+ "unk_token": unk_token,
596
+ "pad_token": pad_token,
597
+ "word_delimiter_token": word_delimiter_token,
598
+ }
599
+
600
+ # 5. Now we can instantiate the feature extractor, tokenizer and model
601
+ # Note for distributed training, the .from_pretrained methods guarantee that only
602
+ # one local process can concurrently download model & vocab.
603
+
604
+ # load feature_extractor and tokenizer
605
+ tokenizer = AutoTokenizer.from_pretrained(
606
+ tokenizer_name_or_path,
607
+ **tokenizer_kwargs,
608
+ )
609
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
610
+ model_args.model_name_or_path,
611
+ cache_dir=model_args.cache_dir,
612
+ )
613
+
614
+ # adapt config
615
+ config.update(
616
+ {
617
+ "feat_proj_dropout": model_args.feat_proj_dropout,
618
+ "attention_dropout": model_args.attention_dropout,
619
+ "hidden_dropout": model_args.hidden_dropout,
620
+ "final_dropout": model_args.final_dropout,
621
+ "mask_time_prob": model_args.mask_time_prob,
622
+ "mask_time_length": model_args.mask_time_length,
623
+ "mask_feature_prob": model_args.mask_feature_prob,
624
+ "mask_feature_length": model_args.mask_feature_length,
625
+ "gradient_checkpointing": training_args.gradient_checkpointing,
626
+ "layerdrop": model_args.layerdrop,
627
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
628
+ "pad_token_id": tokenizer.pad_token_id,
629
+ "vocab_size": len(tokenizer),
630
+ "activation_dropout": model_args.activation_dropout,
631
+ }
632
+ )
633
+
634
+ # create model
635
+ model = AutoModelForCTC.from_pretrained(
636
+ model_args.model_name_or_path,
637
+ cache_dir=model_args.cache_dir,
638
+ config=config,
639
+ )
640
+
641
+ # freeze encoder
642
+ if model_args.freeze_feature_encoder:
643
+ model.freeze_feature_encoder()
644
+
645
+ # 6. Now we preprocess the datasets including loading the audio, resampling and normalization
646
+ # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
647
+ # so that we just need to set the correct target sampling rate and normalize the input
648
+ # via the `feature_extractor`
649
+
650
+ # make sure that dataset decodes audio with correct sampling rate
651
+ dataset_sampling_rate = (
652
+ next(iter(raw_datasets.values()))
653
+ .features[data_args.audio_column_name]
654
+ .sampling_rate
655
+ )
656
+ if dataset_sampling_rate != feature_extractor.sampling_rate:
657
+ raw_datasets = raw_datasets.cast_column(
658
+ data_args.audio_column_name,
659
+ datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
660
+ )
661
+
662
+ # derive max & min input length for sample rate & max duration
663
+ max_input_length = (
664
+ data_args.max_duration_in_seconds * feature_extractor.sampling_rate
665
+ )
666
+ min_input_length = (
667
+ data_args.min_duration_in_seconds * feature_extractor.sampling_rate
668
+ )
669
+ audio_column_name = data_args.audio_column_name
670
+ num_workers = data_args.preprocessing_num_workers
671
+
672
+ # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
673
+ phoneme_language = data_args.phoneme_language
674
+
675
+ # Preprocessing the datasets.
676
+ # We need to read the audio files as arrays and tokenize the targets.
677
+ def prepare_dataset(batch):
678
+ # load audio
679
+ sample = batch[audio_column_name]
680
+
681
+ inputs = feature_extractor(
682
+ sample["array"], sampling_rate=sample["sampling_rate"]
683
+ )
684
+ batch["input_values"] = inputs.input_values[0]
685
+ batch["input_length"] = len(batch["input_values"])
686
+
687
+ # encode targets
688
+ additional_kwargs = {}
689
+ if phoneme_language is not None:
690
+ additional_kwargs["phonemizer_lang"] = phoneme_language
691
+
692
+ batch["labels"] = tokenizer(batch["target_text"], **additional_kwargs).input_ids
693
+ return batch
694
+
695
+ with training_args.main_process_first(desc="dataset map preprocessing"):
696
+ vectorized_datasets = raw_datasets.map(
697
+ prepare_dataset,
698
+ remove_columns=next(iter(raw_datasets.values())).column_names,
699
+ num_proc=num_workers,
700
+ desc="preprocess datasets",
701
+ )
702
+
703
+ def is_audio_in_length_range(length):
704
+ return length > min_input_length and length < max_input_length
705
+
706
+ # filter data that is shorter than min_input_length
707
+ vectorized_datasets = vectorized_datasets.filter(
708
+ is_audio_in_length_range,
709
+ num_proc=num_workers,
710
+ input_columns=["input_length"],
711
+ )
712
+
713
+ # 7. Next, we can prepare the training.
714
+ # Let's use word error rate (WER) as our evaluation metric,
715
+ # instantiate a data collator and the trainer
716
+
717
+ # Define evaluation metrics during training, *i.e.* word error rate, character error rate
718
+ eval_metrics = {metric: evaluate.load(metric) for metric in data_args.eval_metrics}
719
+
720
+ # for large datasets it is advised to run the preprocessing on a
721
+ # single machine first with ``args.preprocessing_only`` since there will mostly likely
722
+ # be a timeout when running the script in distributed mode.
723
+ # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
724
+ # cached dataset
725
+ if data_args.preprocessing_only:
726
+ logger.info(
727
+ f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}"
728
+ )
729
+ return
730
+
731
+ def compute_metrics(pred):
732
+ pred_logits = pred.predictions
733
+ pred_ids = np.argmax(pred_logits, axis=-1)
734
+
735
+ pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
736
+
737
+ pred_str = tokenizer.batch_decode(pred_ids)
738
+ # we do not want to group tokens when computing the metrics
739
+ label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
740
+
741
+ metrics = {
742
+ k: v.compute(predictions=pred_str, references=label_str)
743
+ for k, v in eval_metrics.items()
744
+ }
745
+
746
+ return metrics
747
+
748
+ # Now save everything to be able to create a single processor later
749
+ if is_main_process(training_args.local_rank):
750
+ # save feature extractor, tokenizer and config
751
+ feature_extractor.save_pretrained(training_args.output_dir)
752
+ tokenizer.save_pretrained(training_args.output_dir)
753
+ config.save_pretrained(training_args.output_dir)
754
+
755
+ try:
756
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
757
+ except (OSError, KeyError):
758
+ warnings.warn(
759
+ "Loading a processor from a feature extractor config that does not"
760
+ " include a `processor_class` attribute is deprecated and will be removed in v5. Please add the following "
761
+ " attribute to your `preprocessor_config.json` file to suppress this warning: "
762
+ " `'processor_class': 'Wav2Vec2Processor'`",
763
+ FutureWarning,
764
+ )
765
+ processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)
766
+
767
+ # Instantiate custom data collator
768
+ data_collator = DataCollatorCTCWithPadding(processor=processor)
769
+
770
+ # Initialize Trainer
771
+ trainer = Trainer(
772
+ model=model,
773
+ data_collator=data_collator,
774
+ args=training_args,
775
+ compute_metrics=compute_metrics,
776
+ train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
777
+ eval_dataset=vectorized_datasets["test"] if training_args.do_eval else None,
778
+ tokenizer=feature_extractor,
779
+ )
780
+
781
+ # 8. Finally, we can start training
782
+
783
+ # Training
784
+ if training_args.do_train:
785
+
786
+ # use last checkpoint if exist
787
+ if last_checkpoint is not None:
788
+ checkpoint = last_checkpoint
789
+ elif os.path.isdir(model_args.model_name_or_path):
790
+ checkpoint = model_args.model_name_or_path
791
+ else:
792
+ checkpoint = None
793
+
794
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
795
+ trainer.save_model()
796
+
797
+ metrics = train_result.metrics
798
+ max_train_samples = (
799
+ data_args.max_train_samples
800
+ if data_args.max_train_samples is not None
801
+ else len(vectorized_datasets["train"])
802
+ )
803
+ metrics["train_samples"] = min(
804
+ max_train_samples, len(vectorized_datasets["train"])
805
+ )
806
+
807
+ trainer.log_metrics("train", metrics)
808
+ trainer.save_metrics("train", metrics)
809
+ trainer.save_state()
810
+
811
+ # Evaluation
812
+ results = {}
813
+ if training_args.do_eval:
814
+ logger.info("*** Evaluate ***")
815
+ metrics = trainer.evaluate()
816
+ max_eval_samples = (
817
+ data_args.max_eval_samples
818
+ if data_args.max_eval_samples is not None
819
+ else len(vectorized_datasets["test"])
820
+ )
821
+ metrics["eval_samples"] = min(
822
+ max_eval_samples, len(vectorized_datasets["test"])
823
+ )
824
+
825
+ trainer.log_metrics("test", metrics)
826
+ trainer.save_metrics("test", metrics)
827
+
828
+ # Write model card and (optionally) push to hub
829
+ config_name = (
830
+ data_args.dataset_config_name
831
+ if data_args.dataset_config_name is not None
832
+ else "na"
833
+ )
834
+ kwargs = {
835
+ "finetuned_from": model_args.model_name_or_path,
836
+ "tasks": "automatic-speech-recognition",
837
+ "tags": ["automatic-speech-recognition", data_args.dataset_name],
838
+ "dataset_args": (
839
+ f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
840
+ f" {data_args.eval_split_name}"
841
+ ),
842
+ "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
843
+ }
844
+ if "common_voice" in data_args.dataset_name:
845
+ kwargs["language"] = config_name
846
+
847
+ if training_args.push_to_hub:
848
+ trainer.push_to_hub(**kwargs)
849
+ else:
850
+ trainer.create_model_card(**kwargs)
851
+
852
+ return results
853
+
854
+
855
+ if __name__ == "__main__":
856
+ main()
runs/Jan09_00-52-51_bookbot-pt-2/1673227740.802469/events.out.tfevents.1673227740.bookbot-pt-2.11380.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20aa80f9730f7c36363b97b9caec4240908ec23edcbd227b7f63a990d349d2c0
3
+ size 129
runs/Jan09_00-52-51_bookbot-pt-2/events.out.tfevents.1673227740.bookbot-pt-2.11380.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e104f3a7aae76e3e406df1e908549beab2cf785f91c81343eb46c25099655186
3
+ size 130
runs/Jan09_00-52-51_bookbot-pt-2/events.out.tfevents.1673247932.bookbot-pt-2.11380.2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed25fc9e7b5efd3663ecfa809f39c37a83df79b503f987ed5713e30160e842c4
3
+ size 128
special_tokens_map.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ },
10
+ {
11
+ "content": "</s>",
12
+ "lstrip": false,
13
+ "normalized": true,
14
+ "rstrip": false,
15
+ "single_word": false
16
+ }
17
+ ],
18
+ "bos_token": "<s>",
19
+ "eos_token": "</s>",
20
+ "pad_token": "[PAD]",
21
+ "unk_token": "[UNK]"
22
+ }
test_results.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 30.0,
3
+ "eval_cer": 0.0058143100249896515,
4
+ "eval_loss": 0.06832413375377655,
5
+ "eval_runtime": 87.0708,
6
+ "eval_samples": 1965,
7
+ "eval_samples_per_second": 22.568,
8
+ "eval_steps_per_second": 2.825,
9
+ "eval_wer": 0.009874807524938073
10
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "do_lower_case": false,
4
+ "eos_token": "</s>",
5
+ "model_max_length": 1000000000000000019884624838656,
6
+ "name_or_path": "./wav2vec2-ljspeech-gruut",
7
+ "pad_token": "[PAD]",
8
+ "replace_word_delimiter_char": " ",
9
+ "special_tokens_map_file": null,
10
+ "tokenizer_class": "Wav2Vec2CTCTokenizer",
11
+ "unk_token": "[UNK]",
12
+ "word_delimiter_token": "|"
13
+ }
train_results.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 30.0,
3
+ "train_loss": 0.15939651108792915,
4
+ "train_runtime": 19958.548,
5
+ "train_samples": 11135,
6
+ "train_samples_per_second": 16.737,
7
+ "train_steps_per_second": 0.523
8
+ }
trainer_state.json ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 30.0,
5
+ "global_step": 10440,
6
+ "is_hyper_param_search": false,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 1.0,
12
+ "eval_cer": 1.0,
13
+ "eval_loss": 2.2818498611450195,
14
+ "eval_runtime": 85.7153,
15
+ "eval_samples_per_second": 22.925,
16
+ "eval_steps_per_second": 2.87,
17
+ "eval_wer": 1.0,
18
+ "step": 348
19
+ },
20
+ {
21
+ "epoch": 1.44,
22
+ "learning_rate": 5e-05,
23
+ "loss": 2.6692,
24
+ "step": 500
25
+ },
26
+ {
27
+ "epoch": 2.0,
28
+ "eval_cer": 0.029949254143222897,
29
+ "eval_loss": 0.2044876664876938,
30
+ "eval_runtime": 85.7259,
31
+ "eval_samples_per_second": 22.922,
32
+ "eval_steps_per_second": 2.87,
33
+ "eval_wer": 0.052662850639351944,
34
+ "step": 696
35
+ },
36
+ {
37
+ "epoch": 2.87,
38
+ "learning_rate": 0.0001,
39
+ "loss": 0.2225,
40
+ "step": 1000
41
+ },
42
+ {
43
+ "epoch": 3.0,
44
+ "eval_cer": 0.018884051083140417,
45
+ "eval_loss": 0.11616221815347672,
46
+ "eval_runtime": 85.4863,
47
+ "eval_samples_per_second": 22.986,
48
+ "eval_steps_per_second": 2.878,
49
+ "eval_wer": 0.031858806989355296,
50
+ "step": 1044
51
+ },
52
+ {
53
+ "epoch": 4.0,
54
+ "eval_cer": 0.014687169423704908,
55
+ "eval_loss": 0.09268919378519058,
56
+ "eval_runtime": 85.2634,
57
+ "eval_samples_per_second": 23.046,
58
+ "eval_steps_per_second": 2.885,
59
+ "eval_wer": 0.023540536921737965,
60
+ "step": 1392
61
+ },
62
+ {
63
+ "epoch": 4.31,
64
+ "learning_rate": 9.470338983050848e-05,
65
+ "loss": 0.0868,
66
+ "step": 1500
67
+ },
68
+ {
69
+ "epoch": 5.0,
70
+ "eval_cer": 0.014284728716635749,
71
+ "eval_loss": 0.07971413433551788,
72
+ "eval_runtime": 85.4679,
73
+ "eval_samples_per_second": 22.991,
74
+ "eval_steps_per_second": 2.878,
75
+ "eval_wer": 0.021841735288210484,
76
+ "step": 1740
77
+ },
78
+ {
79
+ "epoch": 5.75,
80
+ "learning_rate": 8.940677966101694e-05,
81
+ "loss": 0.0598,
82
+ "step": 2000
83
+ },
84
+ {
85
+ "epoch": 6.0,
86
+ "eval_cer": 0.012847440477103041,
87
+ "eval_loss": 0.07152710855007172,
88
+ "eval_runtime": 85.5884,
89
+ "eval_samples_per_second": 22.959,
90
+ "eval_steps_per_second": 2.874,
91
+ "eval_wer": 0.01967429872129611,
92
+ "step": 2088
93
+ },
94
+ {
95
+ "epoch": 7.0,
96
+ "eval_cer": 0.010252656108666656,
97
+ "eval_loss": 0.06518065184354782,
98
+ "eval_runtime": 85.121,
99
+ "eval_samples_per_second": 23.085,
100
+ "eval_steps_per_second": 2.89,
101
+ "eval_wer": 0.015975430139921,
102
+ "step": 2436
103
+ },
104
+ {
105
+ "epoch": 7.18,
106
+ "learning_rate": 8.411016949152542e-05,
107
+ "loss": 0.0447,
108
+ "step": 2500
109
+ },
110
+ {
111
+ "epoch": 8.0,
112
+ "eval_cer": 0.00946693853772211,
113
+ "eval_loss": 0.057057663798332214,
114
+ "eval_runtime": 84.3173,
115
+ "eval_samples_per_second": 23.305,
116
+ "eval_steps_per_second": 2.918,
117
+ "eval_wer": 0.015230635335073977,
118
+ "step": 2784
119
+ },
120
+ {
121
+ "epoch": 8.62,
122
+ "learning_rate": 7.88135593220339e-05,
123
+ "loss": 0.0368,
124
+ "step": 3000
125
+ },
126
+ {
127
+ "epoch": 9.0,
128
+ "eval_cer": 0.011180186119245098,
129
+ "eval_loss": 0.060811206698417664,
130
+ "eval_runtime": 84.3248,
131
+ "eval_samples_per_second": 23.303,
132
+ "eval_steps_per_second": 2.917,
133
+ "eval_wer": 0.01630180089710116,
134
+ "step": 3132
135
+ },
136
+ {
137
+ "epoch": 10.0,
138
+ "eval_cer": 0.008297944102902173,
139
+ "eval_loss": 0.058583296835422516,
140
+ "eval_runtime": 84.6721,
141
+ "eval_samples_per_second": 23.207,
142
+ "eval_steps_per_second": 2.905,
143
+ "eval_wer": 0.013657360915846555,
144
+ "step": 3480
145
+ },
146
+ {
147
+ "epoch": 10.06,
148
+ "learning_rate": 7.351694915254238e-05,
149
+ "loss": 0.0303,
150
+ "step": 3500
151
+ },
152
+ {
153
+ "epoch": 11.0,
154
+ "eval_cer": 0.008535575758504913,
155
+ "eval_loss": 0.06412886828184128,
156
+ "eval_runtime": 87.3267,
157
+ "eval_samples_per_second": 22.502,
158
+ "eval_steps_per_second": 2.817,
159
+ "eval_wer": 0.014125995849233446,
160
+ "step": 3828
161
+ },
162
+ {
163
+ "epoch": 11.49,
164
+ "learning_rate": 6.822033898305085e-05,
165
+ "loss": 0.0273,
166
+ "step": 4000
167
+ },
168
+ {
169
+ "epoch": 12.0,
170
+ "eval_cer": 0.007933831082220552,
171
+ "eval_loss": 0.06564020365476608,
172
+ "eval_runtime": 84.2667,
173
+ "eval_samples_per_second": 23.319,
174
+ "eval_steps_per_second": 2.919,
175
+ "eval_wer": 0.013071567249112941,
176
+ "step": 4176
177
+ },
178
+ {
179
+ "epoch": 12.93,
180
+ "learning_rate": 6.29343220338983e-05,
181
+ "loss": 0.0232,
182
+ "step": 4500
183
+ },
184
+ {
185
+ "epoch": 13.0,
186
+ "eval_cer": 0.008225121498765848,
187
+ "eval_loss": 0.06898853182792664,
188
+ "eval_runtime": 84.1359,
189
+ "eval_samples_per_second": 23.355,
190
+ "eval_steps_per_second": 2.924,
191
+ "eval_wer": 0.0132640423110397,
192
+ "step": 4524
193
+ },
194
+ {
195
+ "epoch": 14.0,
196
+ "eval_cer": 0.00787250678400049,
197
+ "eval_loss": 0.05983823910355568,
198
+ "eval_runtime": 84.1873,
199
+ "eval_samples_per_second": 23.341,
200
+ "eval_steps_per_second": 2.922,
201
+ "eval_wer": 0.012803775858606146,
202
+ "step": 4872
203
+ },
204
+ {
205
+ "epoch": 14.37,
206
+ "learning_rate": 5.763771186440679e-05,
207
+ "loss": 0.0189,
208
+ "step": 5000
209
+ },
210
+ {
211
+ "epoch": 15.0,
212
+ "eval_cer": 0.007420240084627531,
213
+ "eval_loss": 0.06711488217115402,
214
+ "eval_runtime": 84.8039,
215
+ "eval_samples_per_second": 23.171,
216
+ "eval_steps_per_second": 2.901,
217
+ "eval_wer": 0.012100823458525808,
218
+ "step": 5220
219
+ },
220
+ {
221
+ "epoch": 15.8,
222
+ "learning_rate": 5.2341101694915265e-05,
223
+ "loss": 0.017,
224
+ "step": 5500
225
+ },
226
+ {
227
+ "epoch": 16.0,
228
+ "eval_cer": 0.006906649087034511,
229
+ "eval_loss": 0.06541039049625397,
230
+ "eval_runtime": 84.2563,
231
+ "eval_samples_per_second": 23.322,
232
+ "eval_steps_per_second": 2.92,
233
+ "eval_wer": 0.011364397134632121,
234
+ "step": 5568
235
+ },
236
+ {
237
+ "epoch": 17.0,
238
+ "eval_cer": 0.007335919174574946,
239
+ "eval_loss": 0.07511687278747559,
240
+ "eval_runtime": 84.0992,
241
+ "eval_samples_per_second": 23.365,
242
+ "eval_steps_per_second": 2.925,
243
+ "eval_wer": 0.011807926625159,
244
+ "step": 5916
245
+ },
246
+ {
247
+ "epoch": 17.24,
248
+ "learning_rate": 4.705508474576271e-05,
249
+ "loss": 0.0146,
250
+ "step": 6000
251
+ },
252
+ {
253
+ "epoch": 18.0,
254
+ "eval_cer": 0.006753338341484355,
255
+ "eval_loss": 0.06527850776910782,
256
+ "eval_runtime": 83.7988,
257
+ "eval_samples_per_second": 23.449,
258
+ "eval_steps_per_second": 2.936,
259
+ "eval_wer": 0.011171922072705363,
260
+ "step": 6264
261
+ },
262
+ {
263
+ "epoch": 18.68,
264
+ "learning_rate": 4.175847457627119e-05,
265
+ "loss": 0.0127,
266
+ "step": 6500
267
+ },
268
+ {
269
+ "epoch": 19.0,
270
+ "eval_cer": 0.006921980161589526,
271
+ "eval_loss": 0.06817645579576492,
272
+ "eval_runtime": 84.0515,
273
+ "eval_samples_per_second": 23.379,
274
+ "eval_steps_per_second": 2.927,
275
+ "eval_wer": 0.01123886992033206,
276
+ "step": 6612
277
+ },
278
+ {
279
+ "epoch": 20.0,
280
+ "eval_cer": 0.006814662639704417,
281
+ "eval_loss": 0.06784532964229584,
282
+ "eval_runtime": 83.9653,
283
+ "eval_samples_per_second": 23.403,
284
+ "eval_steps_per_second": 2.93,
285
+ "eval_wer": 0.01137276561558546,
286
+ "step": 6960
287
+ },
288
+ {
289
+ "epoch": 20.11,
290
+ "learning_rate": 3.6461864406779664e-05,
291
+ "loss": 0.0114,
292
+ "step": 7000
293
+ },
294
+ {
295
+ "epoch": 21.0,
296
+ "eval_cer": 0.006584696521379184,
297
+ "eval_loss": 0.06555593758821487,
298
+ "eval_runtime": 83.8204,
299
+ "eval_samples_per_second": 23.443,
300
+ "eval_steps_per_second": 2.935,
301
+ "eval_wer": 0.011113342706032,
302
+ "step": 7308
303
+ },
304
+ {
305
+ "epoch": 21.55,
306
+ "learning_rate": 3.117584745762712e-05,
307
+ "loss": 0.0101,
308
+ "step": 7500
309
+ },
310
+ {
311
+ "epoch": 22.0,
312
+ "eval_cer": 0.006596194827295445,
313
+ "eval_loss": 0.06685744225978851,
314
+ "eval_runtime": 84.0101,
315
+ "eval_samples_per_second": 23.39,
316
+ "eval_steps_per_second": 2.928,
317
+ "eval_wer": 0.010920867644105242,
318
+ "step": 7656
319
+ },
320
+ {
321
+ "epoch": 22.99,
322
+ "learning_rate": 2.5879237288135593e-05,
323
+ "loss": 0.0092,
324
+ "step": 8000
325
+ },
326
+ {
327
+ "epoch": 23.0,
328
+ "eval_cer": 0.006477378999494075,
329
+ "eval_loss": 0.06765928864479065,
330
+ "eval_runtime": 84.2885,
331
+ "eval_samples_per_second": 23.313,
332
+ "eval_steps_per_second": 2.919,
333
+ "eval_wer": 0.010778603467898508,
334
+ "step": 8004
335
+ },
336
+ {
337
+ "epoch": 24.0,
338
+ "eval_cer": 0.006331733791221427,
339
+ "eval_loss": 0.0652570053935051,
340
+ "eval_runtime": 84.1568,
341
+ "eval_samples_per_second": 23.349,
342
+ "eval_steps_per_second": 2.923,
343
+ "eval_wer": 0.010402021824998326,
344
+ "step": 8352
345
+ },
346
+ {
347
+ "epoch": 24.43,
348
+ "learning_rate": 2.058262711864407e-05,
349
+ "loss": 0.0088,
350
+ "step": 8500
351
+ },
352
+ {
353
+ "epoch": 25.0,
354
+ "eval_cer": 0.006266576724362611,
355
+ "eval_loss": 0.0673212930560112,
356
+ "eval_runtime": 83.9435,
357
+ "eval_samples_per_second": 23.409,
358
+ "eval_steps_per_second": 2.931,
359
+ "eval_wer": 0.01020117828211823,
360
+ "step": 8700
361
+ },
362
+ {
363
+ "epoch": 25.86,
364
+ "learning_rate": 1.5286016949152543e-05,
365
+ "loss": 0.0074,
366
+ "step": 9000
367
+ },
368
+ {
369
+ "epoch": 26.0,
370
+ "eval_cer": 0.006350897634415196,
371
+ "eval_loss": 0.06691750884056091,
372
+ "eval_runtime": 84.013,
373
+ "eval_samples_per_second": 23.389,
374
+ "eval_steps_per_second": 2.928,
375
+ "eval_wer": 0.0104857066345317,
376
+ "step": 9048
377
+ },
378
+ {
379
+ "epoch": 27.0,
380
+ "eval_cer": 0.006113265978812455,
381
+ "eval_loss": 0.0707407295703888,
382
+ "eval_runtime": 84.4435,
383
+ "eval_samples_per_second": 23.27,
384
+ "eval_steps_per_second": 2.913,
385
+ "eval_wer": 0.01013423043449153,
386
+ "step": 9396
387
+ },
388
+ {
389
+ "epoch": 27.3,
390
+ "learning_rate": 9.989406779661017e-06,
391
+ "loss": 0.0066,
392
+ "step": 9500
393
+ },
394
+ {
395
+ "epoch": 28.0,
396
+ "eval_cer": 0.0059829518450948225,
397
+ "eval_loss": 0.06726762652397156,
398
+ "eval_runtime": 84.279,
399
+ "eval_samples_per_second": 23.315,
400
+ "eval_steps_per_second": 2.919,
401
+ "eval_wer": 0.009966860815424784,
402
+ "step": 9744
403
+ },
404
+ {
405
+ "epoch": 28.74,
406
+ "learning_rate": 4.692796610169492e-06,
407
+ "loss": 0.0058,
408
+ "step": 10000
409
+ },
410
+ {
411
+ "epoch": 29.0,
412
+ "eval_cer": 0.005867968785932206,
413
+ "eval_loss": 0.06885003298521042,
414
+ "eval_runtime": 84.0405,
415
+ "eval_samples_per_second": 23.382,
416
+ "eval_steps_per_second": 2.927,
417
+ "eval_wer": 0.010000334739238134,
418
+ "step": 10092
419
+ },
420
+ {
421
+ "epoch": 30.0,
422
+ "eval_cer": 0.0058143100249896515,
423
+ "eval_loss": 0.06832413375377655,
424
+ "eval_runtime": 84.1565,
425
+ "eval_samples_per_second": 23.349,
426
+ "eval_steps_per_second": 2.923,
427
+ "eval_wer": 0.009874807524938073,
428
+ "step": 10440
429
+ },
430
+ {
431
+ "epoch": 30.0,
432
+ "step": 10440,
433
+ "total_flos": 2.004174730615405e+19,
434
+ "train_loss": 0.15939651108792915,
435
+ "train_runtime": 19958.548,
436
+ "train_samples_per_second": 16.737,
437
+ "train_steps_per_second": 0.523
438
+ }
439
+ ],
440
+ "max_steps": 10440,
441
+ "num_train_epochs": 30,
442
+ "total_flos": 2.004174730615405e+19,
443
+ "trial_name": null,
444
+ "trial_params": null
445
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de95cb920d4b0059afc57ae8c102afd51b2c3c85288d1d6f137d0d604bc5c375
3
+ size 129
vocab.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[PAD]": 42,
3
+ "[UNK]": 41,
4
+ "aɪ": 1,
5
+ "aʊ": 2,
6
+ "b": 3,
7
+ "d": 4,
8
+ "d͡ʒ": 5,
9
+ "eɪ": 6,
10
+ "f": 7,
11
+ "h": 8,
12
+ "i": 9,
13
+ "j": 10,
14
+ "k": 11,
15
+ "l": 12,
16
+ "m": 13,
17
+ "n": 14,
18
+ "oʊ": 15,
19
+ "p": 16,
20
+ "s": 17,
21
+ "t": 18,
22
+ "t͡ʃ": 19,
23
+ "u": 20,
24
+ "v": 21,
25
+ "w": 22,
26
+ "z": 23,
27
+ "|": 0,
28
+ "æ": 24,
29
+ "ð": 25,
30
+ "ŋ": 26,
31
+ "ɑ": 27,
32
+ "ɔ": 28,
33
+ "ɔɪ": 29,
34
+ "ə": 30,
35
+ "ɚ": 31,
36
+ "ɛ": 32,
37
+ "ɡ": 33,
38
+ "ɪ": 34,
39
+ "ɹ": 35,
40
+ "ʃ": 36,
41
+ "ʊ": 37,
42
+ "ʌ": 38,
43
+ "ʒ": 39,
44
+ "θ": 40
45
+ }