Yehor Smoliakov committed on
Commit
f8d394f
1 Parent(s): a171ea2
.gitattributes CHANGED
@@ -1,6 +1,7 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
@@ -16,12 +17,13 @@
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *.arpa filter=lfs diff=lfs merge=lfs -text
+ *.txt filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,64 @@
  ---
+ language:
+ - uk
  license: apache-2.0
+ tags:
+ - automatic-speech-recognition
+ - mozilla-foundation/common_voice_7_0
+ - generated_from_trainer
+ - uk
+ xdatasets:
+ - mozilla-foundation/common_voice_7_0
  ---
+
+ # Ukrainian STT model (with a large language model built from a news dataset)
+
+ This model is a fine-tuned version of [facebook/wav2vec2-xls-r-1b](https://huggingface.co/facebook/wav2vec2-xls-r-1b) on the MOZILLA-FOUNDATION/COMMON_VOICE_7_0 - UK dataset.
+
+ Follow our community on Telegram: https://t.me/speech_recognition_uk
+
+ Attribution for the language-model dataset:
+
+ - Chaplynskyi, D. et al. (2021) lang-uk Ukrainian Ubercorpus [Data set]. https://lang.org.ua/uk/corpora/#anchor4
+
+ ## Training procedure
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training:
+ - learning_rate: 5e-05
+ - train_batch_size: 8
+ - eval_batch_size: 8
+ - seed: 42
+ - gradient_accumulation_steps: 20
+ - total_train_batch_size: 160
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
+ - lr_scheduler_type: linear
+ - lr_scheduler_warmup_steps: 500
+ - num_epochs: 100.0
+ - mixed_precision_training: Native AMP
+
+ ### Training results
+
+ | Training Loss | Epoch | Step | Validation Loss | Wer    | Cer    |
+ |:-------------:|:-----:|:----:|:---------------:|:------:|:------:|
+ | 1.2815        | 7.93  | 500  | 0.3536          | 0.4753 | 0.1009 |
+ | 1.0869        | 15.86 | 1000 | 0.2317          | 0.3111 | 0.0614 |
+ | 0.9984        | 23.8  | 1500 | 0.2022          | 0.2676 | 0.0521 |
+ | 0.975         | 31.74 | 2000 | 0.1948          | 0.2469 | 0.0487 |
+ | 0.9306        | 39.67 | 2500 | 0.1916          | 0.2377 | 0.0464 |
+ | 0.8868        | 47.61 | 3000 | 0.1903          | 0.2257 | 0.0439 |
+ | 0.8424        | 55.55 | 3500 | 0.1786          | 0.2206 | 0.0423 |
+ | 0.8126        | 63.49 | 4000 | 0.1849          | 0.2160 | 0.0416 |
+ | 0.7901        | 71.42 | 4500 | 0.1869          | 0.2138 | 0.0413 |
+ | 0.7671        | 79.36 | 5000 | 0.1855          | 0.2075 | 0.0394 |
+ | 0.7467        | 87.3  | 5500 | 0.1884          | 0.2049 | 0.0389 |
+ | 0.731         | 95.24 | 6000 | 0.1877          | 0.2060 | 0.0387 |
+
+
+ ### Framework versions
+
+ - Transformers 4.16.0.dev0
+ - Pytorch 1.10.1+cu102
+ - Datasets 1.18.1.dev0
+ - Tokenizers 0.11.0
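
As a usage illustration for the model card above (not part of this commit), inference could look roughly like the following minimal sketch. It assumes the `transformers` ASR pipeline with `pyctcdecode` and `kenlm` installed (so the bundled 5-gram language model is used during decoding), `ffmpeg` available for reading the audio file, and `<user>/<this-repo>` as a placeholder for the actual Hub model id.

```python
# Minimal, illustrative inference sketch (not part of this commit).
# Assumes: transformers, torch, pyctcdecode and kenlm are installed,
# ffmpeg is available, and "<user>/<this-repo>" is replaced by the real model id.
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="<user>/<this-repo>")

# Chunked decoding helps with long recordings; the values follow the hints in eval.py below.
result = asr("sample_uk.wav", chunk_length_s=5.0, stride_length_s=1.0)
print(result["text"])
```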
added_tokens.json ADDED
@@ -0,0 +1 @@
+ {"<s>": 51, "</s>": 52}
all_results.json ADDED
@@ -0,0 +1,15 @@
+ {
+ "epoch": 99.99,
+ "eval_cer": 0.03843434091927693,
+ "eval_loss": 0.18747110664844513,
+ "eval_runtime": 268.1183,
+ "eval_samples": 4332,
+ "eval_samples_per_second": 16.157,
+ "eval_steps_per_second": 2.021,
+ "eval_wer": 0.20326104163368688,
+ "train_loss": 1.049089940994505,
+ "train_runtime": 95054.1856,
+ "train_samples": 10193,
+ "train_samples_per_second": 10.723,
+ "train_steps_per_second": 0.066
+ }
alphabet.json ADDED
@@ -0,0 +1 @@
+ {"labels": [" ", "a", "c", "e", "i", "j", "k", "l", "m", "n", "o", "p", "u", "x", "y", "\u0301", "\u0430", "\u0431", "\u0432", "\u0433", "\u0434", "\u0435", "\u0436", "\u0437", "\u0438", "\u0439", "\u043a", "\u043b", "\u043c", "\u043d", "\u043e", "\u043f", "\u0440", "\u0441", "\u0442", "\u0443", "\u0444", "\u0445", "\u0446", "\u0447", "\u0448", "\u0449", "\u044c", "\u044e", "\u044f", "\u0454", "\u0456", "\u0457", "\u0491", "\u2047", "", "<s>", "</s>"], "is_bpe": false}
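
In this label list the leading space stands in for the word-delimiter token `|`, the escaped `\uXXXX` codepoints are the combining acute accent and the Ukrainian Cyrillic letters, `\u2047` (⁇) replaces `[UNK]`, and the empty string replaces `[PAD]` (the CTC blank), which is how a pyctcdecode alphabet is typically derived from a Wav2Vec2 tokenizer vocabulary.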
config.json ADDED
@@ -0,0 +1,107 @@
1
+ {
2
+ "_name_or_path": "facebook/wav2vec2-xls-r-1b",
3
+ "activation_dropout": 0.1,
4
+ "adapter_kernel_size": 3,
5
+ "adapter_stride": 2,
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ "Wav2Vec2ForCTC"
10
+ ],
11
+ "attention_dropout": 0.0,
12
+ "bos_token_id": 1,
13
+ "classifier_proj_size": 256,
14
+ "codevector_dim": 1024,
15
+ "contrastive_logits_temperature": 0.1,
16
+ "conv_bias": true,
17
+ "conv_dim": [
18
+ 512,
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512
25
+ ],
26
+ "conv_kernel": [
27
+ 10,
28
+ 3,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 2,
33
+ 2
34
+ ],
35
+ "conv_stride": [
36
+ 5,
37
+ 2,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2
43
+ ],
44
+ "ctc_loss_reduction": "mean",
45
+ "ctc_zero_infinity": false,
46
+ "diversity_loss_weight": 0.1,
47
+ "do_stable_layer_norm": true,
48
+ "eos_token_id": 2,
49
+ "feat_extract_activation": "gelu",
50
+ "feat_extract_dropout": 0.0,
51
+ "feat_extract_norm": "layer",
52
+ "feat_proj_dropout": 0.0,
53
+ "feat_quantizer_dropout": 0.0,
54
+ "final_dropout": 0.0,
55
+ "hidden_act": "gelu",
56
+ "hidden_dropout": 0.0,
57
+ "hidden_size": 1280,
58
+ "initializer_range": 0.02,
59
+ "intermediate_size": 5120,
60
+ "layer_norm_eps": 1e-05,
61
+ "layerdrop": 0.0,
62
+ "mask_feature_length": 64,
63
+ "mask_feature_min_masks": 0,
64
+ "mask_feature_prob": 0.25,
65
+ "mask_time_length": 10,
66
+ "mask_time_min_masks": 2,
67
+ "mask_time_prob": 0.75,
68
+ "model_type": "wav2vec2",
69
+ "num_adapter_layers": 3,
70
+ "num_attention_heads": 16,
71
+ "num_codevector_groups": 2,
72
+ "num_codevectors_per_group": 320,
73
+ "num_conv_pos_embedding_groups": 16,
74
+ "num_conv_pos_embeddings": 128,
75
+ "num_feat_extract_layers": 7,
76
+ "num_hidden_layers": 48,
77
+ "num_negatives": 100,
78
+ "output_hidden_size": 1280,
79
+ "pad_token_id": 50,
80
+ "proj_codevector_dim": 1024,
81
+ "tdnn_dilation": [
82
+ 1,
83
+ 2,
84
+ 3,
85
+ 1,
86
+ 1
87
+ ],
88
+ "tdnn_dim": [
89
+ 512,
90
+ 512,
91
+ 512,
92
+ 512,
93
+ 1500
94
+ ],
95
+ "tdnn_kernel": [
96
+ 5,
97
+ 3,
98
+ 3,
99
+ 1,
100
+ 1
101
+ ],
102
+ "torch_dtype": "float32",
103
+ "transformers_version": "4.16.0.dev0",
104
+ "use_weighted_layer_sum": false,
105
+ "vocab_size": 53,
106
+ "xvector_output_dim": 512
107
+ }
eval.py ADDED
@@ -0,0 +1,127 @@
1
+ #!/usr/bin/env python3
2
+ from datasets import load_dataset, load_metric, Audio, Dataset
3
+ from transformers import pipeline, AutoFeatureExtractor
4
+ import re
5
+ import argparse
6
+ import unicodedata
7
+ from typing import Dict
8
+
9
+
10
+ def log_results(result: Dataset, args: Dict[str, str]):
11
+ """ DO NOT CHANGE. This function computes and logs the result metrics. """
12
+
13
+ log_outputs = args.log_outputs
14
+ dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split])
15
+
16
+ # load metric
17
+ wer = load_metric("wer")
18
+ cer = load_metric("cer")
19
+
20
+ # compute metrics
21
+ wer_result = wer.compute(references=result["target"], predictions=result["prediction"])
22
+ cer_result = cer.compute(references=result["target"], predictions=result["prediction"])
23
+
24
+ # print & log results
25
+ result_str = (
26
+ f"WER: {wer_result}\n"
27
+ f"CER: {cer_result}"
28
+ )
29
+ print(result_str)
30
+
31
+ with open(f"{dataset_id}_eval_results.txt", "w") as f:
32
+ f.write(result_str)
33
+
34
+ # log all results in text file. Possibly interesting for analysis
35
+ if log_outputs is not None:
36
+ pred_file = f"log_{dataset_id}_predictions.txt"
37
+ target_file = f"log_{dataset_id}_targets.txt"
38
+
39
+ with open(pred_file, "w") as p, open(target_file, "w") as t:
40
+
41
+ # mapping function to write output
42
+ def write_to_file(batch, i):
43
+ p.write(f"{i}" + "\n")
44
+ p.write(batch["prediction"] + "\n")
45
+ t.write(f"{i}" + "\n")
46
+ t.write(batch["target"] + "\n")
47
+
48
+ result.map(write_to_file, with_indices=True)
49
+
50
+
51
+ def normalize_text(text: str) -> str:
52
+ """ DO ADAPT FOR YOUR USE CASE. this function normalizes the target text. """
53
+
54
+ chars_to_ignore_regex = '[,?.!\-\;\:\"“%‘”�—’…–]' # noqa: W605 IMPORTANT: this should correspond to the chars that were ignored during training
55
+
56
+ text = text.lower()
57
+ # normalize non-standard (stylized) unicode characters
58
+ text = unicodedata.normalize('NFKC', text)
59
+ # remove punctuation
60
+ text = re.sub(chars_to_ignore_regex, "", text)
61
+
62
+ # Let's also make sure we split on all kinds of newlines, spaces, etc...
63
+ text = " ".join(text.split())
64
+
65
+ return text
66
+
67
+
68
+ def main(args):
69
+ # load dataset
70
+ dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
71
+
72
+ # for testing: only process the first two examples as a test
73
+ # dataset = dataset.select(range(10))
74
+
75
+ # load processor
76
+ feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id)
77
+ sampling_rate = feature_extractor.sampling_rate
78
+
79
+ # resample audio
80
+ dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
81
+
82
+ # load eval pipeline
83
+ asr = pipeline("automatic-speech-recognition", model=args.model_id)
84
+
85
+ # map function to decode audio
86
+ def map_to_pred(batch):
87
+ prediction = asr(batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s)
88
+
89
+ batch["prediction"] = prediction["text"]
90
+ batch["target"] = normalize_text(batch["sentence"])
91
+ return batch
92
+
93
+ # run inference on all examples
94
+ result = dataset.map(map_to_pred, remove_columns=dataset.column_names)
95
+
96
+ # compute and log_results
97
+ # do not change function below
98
+ log_results(result, args)
99
+
100
+
101
+ if __name__ == "__main__":
102
+ parser = argparse.ArgumentParser()
103
+
104
+ parser.add_argument(
105
+ "--model_id", type=str, required=True, help="Model identifier. Should be loadable with 🤗 Transformers"
106
+ )
107
+ parser.add_argument(
108
+ "--dataset", type=str, required=True, help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets"
109
+ )
110
+ parser.add_argument(
111
+ "--config", type=str, required=True, help="Config of the dataset. *E.g.* `'en'` for Common Voice"
112
+ )
113
+ parser.add_argument(
114
+ "--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`"
115
+ )
116
+ parser.add_argument(
117
+ "--chunk_length_s", type=float, default=None, help="Chunk length in seconds. Defaults to None. For long audio files a good value would be 5.0 seconds."
118
+ )
119
+ parser.add_argument(
120
+ "--stride_length_s", type=float, default=None, help="Stride of the audio chunks. Defaults to None. For long audio files a good value would be 1.0 seconds."
121
+ )
122
+ parser.add_argument(
123
+ "--log_outputs", action='store_true', help="If defined, write outputs to log file for analysis."
124
+ )
125
+ args = parser.parse_args()
126
+
127
+ main(args)
eval_results.json ADDED
@@ -0,0 +1,10 @@
+ {
+ "epoch": 99.99,
+ "eval_cer": 0.03843434091927693,
+ "eval_loss": 0.18747110664844513,
+ "eval_runtime": 268.1183,
+ "eval_samples": 4332,
+ "eval_samples_per_second": 16.157,
+ "eval_steps_per_second": 2.021,
+ "eval_wer": 0.20326104163368688
+ }
language_model/5gram.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:216af45a8218db11be16723d33bc458935d310e7d3852297e51af7447689ab3f
+ size 5899238255
language_model/attrs.json ADDED
@@ -0,0 +1 @@
+ {"alpha": 0.5, "beta": 1.5, "unk_score_offset": -10.0, "score_boundary": true}
language_model/unigrams.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8b685f9bd31076e9c32363267e3808a6e17ffd36e90488b360a64f8ffb290c2a
+ size 64208069
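
The `alpha`, `beta` and `unk_score_offset` values in `language_model/attrs.json` are the pyctcdecode beam-search parameters stored alongside the KenLM binary. As an illustrative sketch only (not part of this commit), a decoder with the same settings could be rebuilt roughly as follows, assuming `pyctcdecode` and `kenlm` are installed and the LFS-tracked `5gram.bin` and `unigrams.txt` have been pulled locally:

```python
# Illustrative sketch: rebuilds a CTC beam-search decoder with the parameters
# stored in language_model/attrs.json. Assumes pyctcdecode and kenlm are
# installed and the LFS-tracked files have been downloaded.
import json
from pyctcdecode import build_ctcdecoder

with open("alphabet.json") as f:
    labels = json.load(f)["labels"]

with open("language_model/unigrams.txt") as f:
    unigrams = [line.strip() for line in f if line.strip()]

decoder = build_ctcdecoder(
    labels,
    kenlm_model_path="language_model/5gram.bin",
    unigrams=unigrams,
    alpha=0.5,               # LM weight, from attrs.json
    beta=1.5,                # word-insertion bonus, from attrs.json
    unk_score_offset=-10.0,  # penalty for unknown words, from attrs.json
)
```

In normal use this is not necessary, since `Wav2Vec2ProcessorWithLM.from_pretrained` reconstructs the decoder from these files automatically.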
preprocessor_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+ "do_normalize": true,
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+ "feature_size": 1,
+ "padding_side": "right",
+ "padding_value": 0,
+ "processor_class": "Wav2Vec2ProcessorWithLM",
+ "return_attention_mask": true,
+ "sampling_rate": 16000
+ }
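
Because `processor_class` is `Wav2Vec2ProcessorWithLM` and `sampling_rate` is 16000, a manual decoding pass feeds normalized 16 kHz input to the model and decodes the logits with the LM-aware processor. The following is only a hedged sketch (not part of this commit): `<user>/<this-repo>` is a placeholder id, and a silent one-second buffer stands in for real mono 16 kHz audio.

```python
# Illustrative sketch of manual LM-boosted decoding (not part of this commit).
# "<user>/<this-repo>" is a placeholder; real audio must be mono, 16 kHz float.
import numpy as np
import torch
from transformers import AutoModelForCTC, Wav2Vec2ProcessorWithLM

model_id = "<user>/<this-repo>"
processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
model = AutoModelForCTC.from_pretrained(model_id)

speech = np.zeros(16_000, dtype=np.float32)  # stand-in for one second of audio
inputs = processor(speech, sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# batch_decode runs the pyctcdecode beam search with the bundled 5-gram LM
transcription = processor.batch_decode(logits.numpy()).text[0]
print(transcription)
```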
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6933edff2d809c0035d6ce0fffe2bf975c5d0d3f2932f2c4545f47a769d04ce4
+ size 3850584305
run.sh ADDED
@@ -0,0 +1,36 @@
+ #!/usr/bin/env bash
+
+ python run_speech_recognition_ctc.py \
+ --dataset_name="mozilla-foundation/common_voice_7_0" \
+ --model_name_or_path="facebook/wav2vec2-xls-r-1b" \
+ --dataset_config_name="uk" \
+ --output_dir="./cv-uk-ft" \
+ --num_train_epochs="100" \
+ --per_device_train_batch_size="8" \
+ --per_device_eval_batch_size="8" \
+ --gradient_accumulation_steps="20" \
+ --learning_rate="5e-5" \
+ --warmup_steps="500" \
+ --length_column_name="input_length" \
+ --evaluation_strategy="steps" \
+ --text_column_name="sentence" \
+ --save_steps="500" \
+ --eval_steps="500" \
+ --logging_steps="50" \
+ --layerdrop="0.0" \
+ --activation_dropout="0.1" \
+ --eval_metrics wer cer \
+ --save_total_limit="3" \
+ --feat_proj_dropout="0.0" \
+ --mask_time_prob="0.75" \
+ --mask_time_length="10" \
+ --mask_feature_prob="0.25" \
+ --mask_feature_length="64" \
+ --freeze_feature_encoder \
+ --chars_to_ignore « » ы — – ՚ ’ … , ? . ! - \; \: \" “ % ‘ \` ” � \
+ --fp16 \
+ --group_by_length \
+ --use_auth_token \
+ --report_to="tensorboard" \
+ --do_train --do_eval \
+ --gradient_checkpointing
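
With `--per_device_train_batch_size="8"` and `--gradient_accumulation_steps="20"`, the effective batch size is 8 × 20 = 160 per optimizer step on a single GPU, which matches the `total_train_batch_size: 160` reported in the README above.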
run_speech_recognition_ctc.py ADDED
@@ -0,0 +1,737 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ """ Fine-tuning a 🤗 Transformers CTC model for automatic speech recognition"""
17
+
18
+ import functools
19
+ import json
20
+ import logging
21
+ import os
22
+ import re
23
+ import sys
24
+ import warnings
25
+ from dataclasses import dataclass, field
26
+ from typing import Dict, List, Optional, Union
27
+
28
+ import datasets
29
+ import numpy as np
30
+ import torch
31
+ from datasets import DatasetDict, load_dataset, load_metric
32
+
33
+ import transformers
34
+ from transformers import (
35
+ AutoConfig,
36
+ AutoFeatureExtractor,
37
+ AutoModelForCTC,
38
+ AutoProcessor,
39
+ AutoTokenizer,
40
+ HfArgumentParser,
41
+ Trainer,
42
+ TrainingArguments,
43
+ Wav2Vec2Processor,
44
+ set_seed,
45
+ )
46
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
47
+ from transformers.utils import check_min_version
48
+ from transformers.utils.versions import require_version
49
+
50
+
51
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
52
+ check_min_version("4.16.0.dev0")
53
+
54
+ require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
55
+
56
+
57
+ logger = logging.getLogger(__name__)
58
+
59
+
60
+ def list_field(default=None, metadata=None):
61
+ return field(default_factory=lambda: default, metadata=metadata)
62
+
63
+
64
+ @dataclass
65
+ class ModelArguments:
66
+ """
67
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
68
+ """
69
+
70
+ model_name_or_path: str = field(
71
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
72
+ )
73
+ tokenizer_name_or_path: Optional[str] = field(
74
+ default=None,
75
+ metadata={"help": "Path to pretrained tokenizer or tokenizer identifier from huggingface.co/models"},
76
+ )
77
+ cache_dir: Optional[str] = field(
78
+ default=None,
79
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
80
+ )
81
+ freeze_feature_encoder: bool = field(
82
+ default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
83
+ )
84
+ attention_dropout: float = field(
85
+ default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
86
+ )
87
+ activation_dropout: float = field(
88
+ default=0.0, metadata={"help": "The dropout ratio for activations inside the fully connected layer."}
89
+ )
90
+ feat_proj_dropout: float = field(default=0.0, metadata={"help": "The dropout ratio for the projected features."})
91
+ hidden_dropout: float = field(
92
+ default=0.0,
93
+ metadata={
94
+ "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
95
+ },
96
+ )
97
+ final_dropout: float = field(
98
+ default=0.0,
99
+ metadata={"help": "The dropout probability for the final projection layer."},
100
+ )
101
+ mask_time_prob: float = field(
102
+ default=0.05,
103
+ metadata={
104
+ "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
105
+ "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
106
+ "vectors will be masked along the time axis."
107
+ },
108
+ )
109
+ mask_time_length: int = field(
110
+ default=10,
111
+ metadata={"help": "Length of vector span to mask along the time axis."},
112
+ )
113
+ mask_feature_prob: float = field(
114
+ default=0.0,
115
+ metadata={
116
+ "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
117
+ "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
118
+ },
119
+ )
120
+ mask_feature_length: int = field(
121
+ default=10,
122
+ metadata={"help": "Length of vector span to mask along the feature axis."},
123
+ )
124
+ layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
125
+ ctc_loss_reduction: Optional[str] = field(
126
+ default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
127
+ )
128
+
129
+
130
+ @dataclass
131
+ class DataTrainingArguments:
132
+ """
133
+ Arguments pertaining to what data we are going to input our model for training and eval.
134
+
135
+ Using `HfArgumentParser` we can turn this class
136
+ into argparse arguments to be able to specify them on
137
+ the command line.
138
+ """
139
+
140
+ dataset_name: str = field(
141
+ metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
142
+ )
143
+ dataset_config_name: str = field(
144
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
145
+ )
146
+ train_split_name: str = field(
147
+ default="train+validation",
148
+ metadata={
149
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
150
+ },
151
+ )
152
+ eval_split_name: str = field(
153
+ default="test",
154
+ metadata={
155
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
156
+ },
157
+ )
158
+ audio_column_name: str = field(
159
+ default="audio",
160
+ metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
161
+ )
162
+ text_column_name: str = field(
163
+ default="text",
164
+ metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
165
+ )
166
+ overwrite_cache: bool = field(
167
+ default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
168
+ )
169
+ preprocessing_num_workers: Optional[int] = field(
170
+ default=None,
171
+ metadata={"help": "The number of processes to use for the preprocessing."},
172
+ )
173
+ max_train_samples: Optional[int] = field(
174
+ default=None,
175
+ metadata={
176
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
177
+ "value if set."
178
+ },
179
+ )
180
+ max_eval_samples: Optional[int] = field(
181
+ default=None,
182
+ metadata={
183
+ "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
184
+ "value if set."
185
+ },
186
+ )
187
+ chars_to_ignore: Optional[List[str]] = list_field(
188
+ default=None,
189
+ metadata={"help": "A list of characters to remove from the transcripts."},
190
+ )
191
+ eval_metrics: List[str] = list_field(
192
+ default=["wer"],
193
+ metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
194
+ )
195
+ max_duration_in_seconds: float = field(
196
+ default=20.0,
197
+ metadata={
198
+ "help": "Filter audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`"
199
+ },
200
+ )
201
+ min_duration_in_seconds: float = field(
202
+ default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
203
+ )
204
+ preprocessing_only: bool = field(
205
+ default=False,
206
+ metadata={
207
+ "help": "Whether to only do data preprocessing and skip training. "
208
+ "This is especially useful when data preprocessing errors out in distributed training due to timeout. "
209
+ "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
210
+ "so that the cached datasets can consequently be loaded in distributed training"
211
+ },
212
+ )
213
+ use_auth_token: bool = field(
214
+ default=False,
215
+ metadata={
216
+ "help": "If :obj:`True`, will use the token generated when running"
217
+ ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
218
+ },
219
+ )
220
+ unk_token: str = field(
221
+ default="[UNK]",
222
+ metadata={"help": "The unk token for the tokenizer"},
223
+ )
224
+ pad_token: str = field(
225
+ default="[PAD]",
226
+ metadata={"help": "The padding token for the tokenizer"},
227
+ )
228
+ word_delimiter_token: str = field(
229
+ default="|",
230
+ metadata={"help": "The word delimiter token for the tokenizer"},
231
+ )
232
+ phoneme_language: Optional[str] = field(
233
+ default=None,
234
+ metadata={
235
+ "help": "The target language that should be used be"
236
+ " passed to the tokenizer for tokenization. Note that"
237
+ " this is only relevant if the model classifies the"
238
+ " input audio to a sequence of phoneme sequences."
239
+ },
240
+ )
241
+
242
+
243
+ @dataclass
244
+ class DataCollatorCTCWithPadding:
245
+ """
246
+ Data collator that will dynamically pad the inputs received.
247
+ Args:
248
+ processor (:class:`~transformers.AutoProcessor`)
249
+ The processor used for proccessing the data.
250
+ padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
251
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
252
+ among:
253
+ * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
254
+ sequence if provided).
255
+ * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
256
+ maximum acceptable input length for the model if that argument is not provided.
257
+ * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
258
+ different lengths).
259
+ max_length (:obj:`int`, `optional`):
260
+ Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
261
+ max_length_labels (:obj:`int`, `optional`):
262
+ Maximum length of the ``labels`` returned list and optionally padding length (see above).
263
+ pad_to_multiple_of (:obj:`int`, `optional`):
264
+ If set will pad the sequence to a multiple of the provided value.
265
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
266
+ 7.5 (Volta).
267
+ """
268
+
269
+ processor: AutoProcessor
270
+ padding: Union[bool, str] = "longest"
271
+ pad_to_multiple_of: Optional[int] = None
272
+ pad_to_multiple_of_labels: Optional[int] = None
273
+
274
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
275
+ # split inputs and labels since they have to be of different lenghts and need
276
+ # different padding methods
277
+ input_features = [{"input_values": feature["input_values"]} for feature in features]
278
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
279
+
280
+ batch = self.processor.pad(
281
+ input_features,
282
+ padding=self.padding,
283
+ pad_to_multiple_of=self.pad_to_multiple_of,
284
+ return_tensors="pt",
285
+ )
286
+
287
+ with self.processor.as_target_processor():
288
+ labels_batch = self.processor.pad(
289
+ label_features,
290
+ padding=self.padding,
291
+ pad_to_multiple_of=self.pad_to_multiple_of_labels,
292
+ return_tensors="pt",
293
+ )
294
+
295
+ # replace padding with -100 to ignore loss correctly
296
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
297
+
298
+ batch["labels"] = labels
299
+
300
+ return batch
301
+
302
+
303
+ def create_vocabulary_from_data(
304
+ datasets: DatasetDict,
305
+ word_delimiter_token: Optional[str] = None,
306
+ unk_token: Optional[str] = None,
307
+ pad_token: Optional[str] = None,
308
+ ):
309
+ # Given training and test labels create vocabulary
310
+ def extract_all_chars(batch):
311
+ all_text = " ".join(batch["target_text"])
312
+ vocab = list(set(all_text))
313
+ return {"vocab": [vocab], "all_text": [all_text]}
314
+
315
+ vocabs = datasets.map(
316
+ extract_all_chars,
317
+ batched=True,
318
+ batch_size=-1,
319
+ keep_in_memory=True,
320
+ remove_columns=datasets["train"].column_names,
321
+ )
322
+
323
+ # take union of all unique characters in each dataset
324
+ vocab_set = functools.reduce(
325
+ lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values()
326
+ )
327
+
328
+ vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))}
329
+
330
+ # replace white space with delimiter token
331
+ if word_delimiter_token is not None:
332
+ vocab_dict[word_delimiter_token] = vocab_dict[" "]
333
+ del vocab_dict[" "]
334
+
335
+ # add unk and pad token
336
+ if unk_token is not None:
337
+ vocab_dict[unk_token] = len(vocab_dict)
338
+
339
+ if pad_token is not None:
340
+ vocab_dict[pad_token] = len(vocab_dict)
341
+
342
+ return vocab_dict
343
+
344
+
345
+ def main():
346
+ # See all possible arguments in src/transformers/training_args.py
347
+ # or by passing the --help flag to this script.
348
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
349
+
350
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
351
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
352
+ # If we pass only one argument to the script and it's the path to a json file,
353
+ # let's parse it to get our arguments.
354
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
355
+ else:
356
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
357
+
358
+ # Detecting last checkpoint.
359
+ last_checkpoint = None
360
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
361
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
362
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
363
+ raise ValueError(
364
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
365
+ "Use --overwrite_output_dir to overcome."
366
+ )
367
+ elif last_checkpoint is not None:
368
+ logger.info(
369
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
370
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
371
+ )
372
+
373
+ # Setup logging
374
+ logging.basicConfig(
375
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
376
+ datefmt="%m/%d/%Y %H:%M:%S",
377
+ handlers=[logging.StreamHandler(sys.stdout)],
378
+ )
379
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
380
+
381
+ # Log on each process the small summary:
382
+ logger.warning(
383
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
384
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
385
+ )
386
+ # Set the verbosity to info of the Transformers logger (on main process only):
387
+ if is_main_process(training_args.local_rank):
388
+ transformers.utils.logging.set_verbosity_info()
389
+ logger.info("Training/evaluation parameters %s", training_args)
390
+
391
+ # Set seed before initializing model.
392
+ set_seed(training_args.seed)
393
+
394
+ # 1. First, let's load the dataset
395
+ raw_datasets = DatasetDict()
396
+
397
+ if training_args.do_train:
398
+ raw_datasets["train"] = load_dataset(
399
+ data_args.dataset_name,
400
+ data_args.dataset_config_name,
401
+ split=data_args.train_split_name,
402
+ use_auth_token=data_args.use_auth_token,
403
+ )
404
+
405
+ if data_args.audio_column_name not in raw_datasets["train"].column_names:
406
+ raise ValueError(
407
+ f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
408
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
409
+ f"{', '.join(raw_datasets['train'].column_names)}."
410
+ )
411
+
412
+ if data_args.text_column_name not in raw_datasets["train"].column_names:
413
+ raise ValueError(
414
+ f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
415
+ "Make sure to set `--text_column_name` to the correct text column - one of "
416
+ f"{', '.join(raw_datasets['train'].column_names)}."
417
+ )
418
+
419
+ if data_args.max_train_samples is not None:
420
+ raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
421
+
422
+ if training_args.do_eval:
423
+ raw_datasets["eval"] = load_dataset(
424
+ data_args.dataset_name,
425
+ data_args.dataset_config_name,
426
+ split=data_args.eval_split_name,
427
+ use_auth_token=data_args.use_auth_token,
428
+ )
429
+
430
+ if data_args.max_eval_samples is not None:
431
+ raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
432
+
433
+ # 2. We remove some special characters from the datasets
434
+ # that make training complicated and do not help in transcribing the speech
435
+ # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
436
+ # that could be easily picked up by the model
437
+ chars_to_ignore_regex = (
438
+ f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
439
+ )
440
+ text_column_name = data_args.text_column_name
441
+
442
+ def remove_special_characters(batch):
443
+ if chars_to_ignore_regex is not None:
444
+ batch["target_text"] = re.sub(chars_to_ignore_regex, "", batch[text_column_name]).lower() + " "
445
+ else:
446
+ batch["target_text"] = batch[text_column_name].lower() + " "
447
+ return batch
448
+
449
+ with training_args.main_process_first(desc="dataset map special characters removal"):
450
+ raw_datasets = raw_datasets.map(
451
+ remove_special_characters,
452
+ remove_columns=[text_column_name],
453
+ desc="remove special characters from datasets",
454
+ )
455
+
456
+ # save special tokens for tokenizer
457
+ word_delimiter_token = data_args.word_delimiter_token
458
+ unk_token = data_args.unk_token
459
+ pad_token = data_args.pad_token
460
+
461
+ # 3. Next, let's load the config as we might need it to create
462
+ # the tokenizer
463
+ # load config
464
+ config = AutoConfig.from_pretrained(
465
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
466
+ )
467
+
468
+ # 4. Next, if no tokenizer file is defined,
469
+ # we create the vocabulary of the model by extracting all unique characters from
470
+ # the training and evaluation datasets
471
+ # We need to make sure that only first rank saves vocabulary
472
+ # make sure all processes wait until vocab is created
473
+ tokenizer_name_or_path = model_args.tokenizer_name_or_path
474
+ tokenizer_kwargs = {}
475
+ if tokenizer_name_or_path is None:
476
+ # save vocab in training output dir
477
+ tokenizer_name_or_path = training_args.output_dir
478
+
479
+ vocab_file = os.path.join(tokenizer_name_or_path, "vocab.json")
480
+
481
+ with training_args.main_process_first():
482
+ if training_args.overwrite_output_dir and os.path.isfile(vocab_file):
483
+ os.remove(vocab_file)
484
+
485
+ with training_args.main_process_first(desc="dataset map vocabulary creation"):
486
+ if not os.path.isfile(vocab_file):
487
+ os.makedirs(tokenizer_name_or_path, exist_ok=True)
488
+ vocab_dict = create_vocabulary_from_data(
489
+ raw_datasets,
490
+ word_delimiter_token=word_delimiter_token,
491
+ unk_token=unk_token,
492
+ pad_token=pad_token,
493
+ )
494
+
495
+ # save vocab dict to be loaded into tokenizer
496
+ with open(vocab_file, "w") as file:
497
+ json.dump(vocab_dict, file)
498
+
499
+ # if tokenizer has just been created
500
+ # it is defined by `tokenizer_class` if present in config else by `model_type`
501
+ tokenizer_kwargs = {
502
+ "config": config if config.tokenizer_class is not None else None,
503
+ "tokenizer_type": config.model_type if config.tokenizer_class is None else None,
504
+ "unk_token": unk_token,
505
+ "pad_token": pad_token,
506
+ "word_delimiter_token": word_delimiter_token,
507
+ }
508
+
509
+ # 5. Now we can instantiate the feature extractor, tokenizer and model
510
+ # Note for distributed training, the .from_pretrained methods guarantee that only
511
+ # one local process can concurrently download model & vocab.
512
+
513
+ # load feature_extractor and tokenizer
514
+ tokenizer = AutoTokenizer.from_pretrained(
515
+ tokenizer_name_or_path,
516
+ use_auth_token=data_args.use_auth_token,
517
+ **tokenizer_kwargs,
518
+ )
519
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
520
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
521
+ )
522
+
523
+ # adapt config
524
+ config.update(
525
+ {
526
+ "feat_proj_dropout": model_args.feat_proj_dropout,
527
+ "attention_dropout": model_args.attention_dropout,
528
+ "hidden_dropout": model_args.hidden_dropout,
529
+ "final_dropout": model_args.final_dropout,
530
+ "mask_time_prob": model_args.mask_time_prob,
531
+ "mask_time_length": model_args.mask_time_length,
532
+ "mask_feature_prob": model_args.mask_feature_prob,
533
+ "mask_feature_length": model_args.mask_feature_length,
534
+ "gradient_checkpointing": training_args.gradient_checkpointing,
535
+ "layerdrop": model_args.layerdrop,
536
+ "ctc_loss_reduction": model_args.ctc_loss_reduction,
537
+ "pad_token_id": tokenizer.pad_token_id,
538
+ "vocab_size": len(tokenizer),
539
+ "activation_dropout": model_args.activation_dropout,
540
+ }
541
+ )
542
+
543
+ # create model
544
+ model = AutoModelForCTC.from_pretrained(
545
+ model_args.model_name_or_path,
546
+ cache_dir=model_args.cache_dir,
547
+ config=config,
548
+ use_auth_token=data_args.use_auth_token,
549
+ )
550
+
551
+ # freeze encoder
552
+ if model_args.freeze_feature_encoder:
553
+ model.freeze_feature_encoder()
554
+
555
+ # 6. Now we preprocess the datasets including loading the audio, resampling and normalization
556
+ # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
557
+ # so that we just need to set the correct target sampling rate and normalize the input
558
+ # via the `feature_extractor`
559
+
560
+ # make sure that dataset decodes audio with correct sampling rate
561
+ dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
562
+ if dataset_sampling_rate != feature_extractor.sampling_rate:
563
+ raw_datasets = raw_datasets.cast_column(
564
+ data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
565
+ )
566
+
567
+ # derive max & min input length for sample rate & max duration
568
+ max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
569
+ min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
570
+ audio_column_name = data_args.audio_column_name
571
+ num_workers = data_args.preprocessing_num_workers
572
+
573
+ # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
574
+ phoneme_language = data_args.phoneme_language
575
+
576
+ # Preprocessing the datasets.
577
+ # We need to read the audio files as arrays and tokenize the targets.
578
+ def prepare_dataset(batch):
579
+ # load audio
580
+ sample = batch[audio_column_name]
581
+
582
+ inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
583
+ batch["input_values"] = inputs.input_values[0]
584
+ batch["input_length"] = len(batch["input_values"])
585
+
586
+ # encode targets
587
+ additional_kwargs = {}
588
+ if phoneme_language is not None:
589
+ additional_kwargs["phonemizer_lang"] = phoneme_language
590
+
591
+ batch["labels"] = tokenizer(batch["target_text"], **additional_kwargs).input_ids
592
+ return batch
593
+
594
+ with training_args.main_process_first(desc="dataset map preprocessing"):
595
+ vectorized_datasets = raw_datasets.map(
596
+ prepare_dataset,
597
+ remove_columns=next(iter(raw_datasets.values())).column_names,
598
+ num_proc=num_workers,
599
+ desc="preprocess datasets",
600
+ )
601
+
602
+ def is_audio_in_length_range(length):
603
+ return length > min_input_length and length < max_input_length
604
+
605
+ # filter data that is shorter than min_input_length
606
+ vectorized_datasets = vectorized_datasets.filter(
607
+ is_audio_in_length_range,
608
+ num_proc=num_workers,
609
+ input_columns=["input_length"],
610
+ )
611
+
612
+ # 7. Next, we can prepare the training.
613
+ # Let's use word error rate (WER) as our evaluation metric,
614
+ # instantiate a data collator and the trainer
615
+
616
+ # Define evaluation metrics during training, *i.e.* word error rate, character error rate
617
+ eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
618
+
619
+ # for large datasets it is advised to run the preprocessing on a
620
+ # single machine first with ``args.preprocessing_only`` since there will mostly likely
621
+ # be a timeout when running the script in distributed mode.
622
+ # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
623
+ # cached dataset
624
+ if data_args.preprocessing_only:
625
+ logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
626
+ return
627
+
628
+ def compute_metrics(pred):
629
+ pred_logits = pred.predictions
630
+ pred_ids = np.argmax(pred_logits, axis=-1)
631
+
632
+ pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
633
+
634
+ pred_str = tokenizer.batch_decode(pred_ids)
635
+ # we do not want to group tokens when computing the metrics
636
+ label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
637
+
638
+ metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
639
+
640
+ return metrics
641
+
642
+ # Now save everything to be able to create a single processor later
643
+ if is_main_process(training_args.local_rank):
644
+ # save feature extractor, tokenizer and config
645
+ feature_extractor.save_pretrained(training_args.output_dir)
646
+ tokenizer.save_pretrained(training_args.output_dir)
647
+ config.save_pretrained(training_args.output_dir)
648
+
649
+ try:
650
+ processor = AutoProcessor.from_pretrained(training_args.output_dir)
651
+ except (OSError, KeyError):
652
+ warnings.warn(
653
+ "Loading a processor from a feature extractor config that does not"
654
+ " include a `processor_class` attribute is deprecated and will be removed in v5. Please add the following "
655
+ " attribute to your `preprocessor_config.json` file to suppress this warning: "
656
+ " `'processor_class': 'Wav2Vec2Processor'`",
657
+ FutureWarning,
658
+ )
659
+ processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)
660
+
661
+ # Instantiate custom data collator
662
+ data_collator = DataCollatorCTCWithPadding(processor=processor)
663
+
664
+ # Initialize Trainer
665
+ trainer = Trainer(
666
+ model=model,
667
+ data_collator=data_collator,
668
+ args=training_args,
669
+ compute_metrics=compute_metrics,
670
+ train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
671
+ eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
672
+ tokenizer=feature_extractor,
673
+ )
674
+
675
+ # 8. Finally, we can start training
676
+
677
+ # Training
678
+ if training_args.do_train:
679
+
680
+ # use last checkpoint if exist
681
+ if last_checkpoint is not None:
682
+ checkpoint = last_checkpoint
683
+ elif os.path.isdir(model_args.model_name_or_path):
684
+ checkpoint = model_args.model_name_or_path
685
+ else:
686
+ checkpoint = None
687
+
688
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
689
+ trainer.save_model()
690
+
691
+ metrics = train_result.metrics
692
+ max_train_samples = (
693
+ data_args.max_train_samples
694
+ if data_args.max_train_samples is not None
695
+ else len(vectorized_datasets["train"])
696
+ )
697
+ metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
698
+
699
+ trainer.log_metrics("train", metrics)
700
+ trainer.save_metrics("train", metrics)
701
+ trainer.save_state()
702
+
703
+ # Evaluation
704
+ results = {}
705
+ if training_args.do_eval:
706
+ logger.info("*** Evaluate ***")
707
+ metrics = trainer.evaluate()
708
+ max_eval_samples = (
709
+ data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
710
+ )
711
+ metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
712
+
713
+ trainer.log_metrics("eval", metrics)
714
+ trainer.save_metrics("eval", metrics)
715
+
716
+ # Write model card and (optionally) push to hub
717
+ config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
718
+ kwargs = {
719
+ "finetuned_from": model_args.model_name_or_path,
720
+ "tasks": "speech-recognition",
721
+ "tags": ["automatic-speech-recognition", data_args.dataset_name],
722
+ "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
723
+ "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
724
+ }
725
+ if "common_voice" in data_args.dataset_name:
726
+ kwargs["language"] = config_name
727
+
728
+ if training_args.push_to_hub:
729
+ trainer.push_to_hub(**kwargs)
730
+ else:
731
+ trainer.create_model_card(**kwargs)
732
+
733
+ return results
734
+
735
+
736
+ if __name__ == "__main__":
737
+ main()
runs/Jan26_11-22-42_job-df329b21-d243-4736-8f96-d11192aeb370/1643196546.6195295/events.out.tfevents.1643196546.job-df329b21-d243-4736-8f96-d11192aeb370.13810.1 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dccefa87dfb9f1a70748307d90324380c39d8b93d9af111a89d51ba81de70d24
+ size 4778
runs/Jan26_11-22-42_job-df329b21-d243-4736-8f96-d11192aeb370/events.out.tfevents.1643196546.job-df329b21-d243-4736-8f96-d11192aeb370.13810.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c1dd1de0119ae5b997385fc2a7feb994debd200ac3055775715f8355a07eb0d2
+ size 29217
runs/Jan26_11-22-42_job-df329b21-d243-4736-8f96-d11192aeb370/events.out.tfevents.1643291875.job-df329b21-d243-4736-8f96-d11192aeb370.13810.2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cddb9d99a97b8c4323dfd96a1f6fef25c3c7ff9009b73aa856629224e0ce4325
+ size 405
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]", "additional_special_tokens": [{"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}]}
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "[UNK]", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "[PAD]", "do_lower_case": false, "word_delimiter_token": "|", "special_tokens_map_file": null, "tokenizer_file": null, "name_or_path": "final", "tokenizer_class": "Wav2Vec2CTCTokenizer", "processor_class": "Wav2Vec2ProcessorWithLM"}
train_results.json ADDED
@@ -0,0 +1,8 @@
+ {
+ "epoch": 99.99,
+ "train_loss": 1.049089940994505,
+ "train_runtime": 95054.1856,
+ "train_samples": 10193,
+ "train_samples_per_second": 10.723,
+ "train_steps_per_second": 0.066
+ }
trainer_state.json ADDED
@@ -0,0 +1,901 @@
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 99.98823529411764,
5
+ "global_step": 6300,
6
+ "is_hyper_param_search": false,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 0.78,
12
+ "learning_rate": 4.9000000000000005e-06,
13
+ "loss": 11.1133,
14
+ "step": 50
15
+ },
16
+ {
17
+ "epoch": 1.58,
18
+ "learning_rate": 9.900000000000002e-06,
19
+ "loss": 3.3967,
20
+ "step": 100
21
+ },
22
+ {
23
+ "epoch": 2.38,
24
+ "learning_rate": 1.49e-05,
25
+ "loss": 3.2205,
26
+ "step": 150
27
+ },
28
+ {
29
+ "epoch": 3.17,
30
+ "learning_rate": 1.9900000000000003e-05,
31
+ "loss": 2.8143,
32
+ "step": 200
33
+ },
34
+ {
35
+ "epoch": 3.96,
36
+ "learning_rate": 2.4900000000000002e-05,
37
+ "loss": 1.9249,
38
+ "step": 250
39
+ },
40
+ {
41
+ "epoch": 4.75,
42
+ "learning_rate": 2.9900000000000002e-05,
43
+ "loss": 1.6708,
44
+ "step": 300
45
+ },
46
+ {
47
+ "epoch": 5.55,
48
+ "learning_rate": 3.49e-05,
49
+ "loss": 1.5501,
50
+ "step": 350
51
+ },
52
+ {
53
+ "epoch": 6.35,
54
+ "learning_rate": 3.99e-05,
55
+ "loss": 1.4258,
56
+ "step": 400
57
+ },
58
+ {
59
+ "epoch": 7.14,
60
+ "learning_rate": 4.49e-05,
61
+ "loss": 1.332,
62
+ "step": 450
63
+ },
64
+ {
65
+ "epoch": 7.93,
66
+ "learning_rate": 4.99e-05,
67
+ "loss": 1.2815,
68
+ "step": 500
69
+ },
70
+ {
71
+ "epoch": 7.93,
72
+ "eval_cer": 0.10093122852447588,
73
+ "eval_loss": 0.35359087586402893,
74
+ "eval_runtime": 257.0651,
75
+ "eval_samples_per_second": 16.852,
76
+ "eval_steps_per_second": 2.108,
77
+ "eval_wer": 0.47525724236188066,
78
+ "step": 500
79
+ },
80
+ {
81
+ "epoch": 8.72,
82
+ "learning_rate": 4.957758620689655e-05,
83
+ "loss": 1.2632,
84
+ "step": 550
85
+ },
86
+ {
87
+ "epoch": 9.52,
88
+ "learning_rate": 4.9146551724137934e-05,
89
+ "loss": 1.2239,
90
+ "step": 600
91
+ },
92
+ {
93
+ "epoch": 10.31,
94
+ "learning_rate": 4.871551724137931e-05,
95
+ "loss": 1.2044,
96
+ "step": 650
97
+ },
98
+ {
99
+ "epoch": 11.11,
100
+ "learning_rate": 4.828448275862069e-05,
101
+ "loss": 1.1918,
102
+ "step": 700
103
+ },
104
+ {
105
+ "epoch": 11.89,
106
+ "learning_rate": 4.785344827586207e-05,
107
+ "loss": 1.1641,
108
+ "step": 750
109
+ },
110
+ {
111
+ "epoch": 12.69,
112
+ "learning_rate": 4.742241379310345e-05,
113
+ "loss": 1.1718,
114
+ "step": 800
115
+ },
116
+ {
117
+ "epoch": 13.49,
118
+ "learning_rate": 4.699137931034483e-05,
119
+ "loss": 1.1638,
120
+ "step": 850
121
+ },
122
+ {
123
+ "epoch": 14.28,
124
+ "learning_rate": 4.656034482758621e-05,
125
+ "loss": 1.1317,
126
+ "step": 900
127
+ },
128
+ {
129
+ "epoch": 15.08,
130
+ "learning_rate": 4.612931034482759e-05,
131
+ "loss": 1.1334,
132
+ "step": 950
133
+ },
134
+ {
135
+ "epoch": 15.86,
136
+ "learning_rate": 4.569827586206897e-05,
137
+ "loss": 1.0869,
138
+ "step": 1000
139
+ },
140
+ {
141
+ "epoch": 15.86,
142
+ "eval_cer": 0.06135152631841044,
143
+ "eval_loss": 0.23165984451770782,
144
+ "eval_runtime": 262.4285,
145
+ "eval_samples_per_second": 16.507,
146
+ "eval_steps_per_second": 2.065,
147
+ "eval_wer": 0.3110653791356657,
148
+ "step": 1000
149
+ },
150
+ {
151
+ "epoch": 16.66,
152
+ "learning_rate": 4.526724137931035e-05,
153
+ "loss": 1.104,
154
+ "step": 1050
155
+ },
156
+ {
157
+ "epoch": 17.45,
158
+ "learning_rate": 4.4836206896551726e-05,
159
+ "loss": 1.109,
160
+ "step": 1100
161
+ },
162
+ {
163
+ "epoch": 18.25,
164
+ "learning_rate": 4.440517241379311e-05,
165
+ "loss": 1.0902,
166
+ "step": 1150
167
+ },
168
+ {
169
+ "epoch": 19.05,
170
+ "learning_rate": 4.397413793103449e-05,
171
+ "loss": 1.0676,
172
+ "step": 1200
173
+ },
174
+ {
175
+ "epoch": 19.83,
176
+ "learning_rate": 4.3543103448275865e-05,
177
+ "loss": 1.0453,
178
+ "step": 1250
179
+ },
180
+ {
181
+ "epoch": 20.63,
182
+ "learning_rate": 4.311206896551725e-05,
183
+ "loss": 1.0489,
184
+ "step": 1300
185
+ },
186
+ {
187
+ "epoch": 21.42,
188
+ "learning_rate": 4.268103448275862e-05,
189
+ "loss": 1.0495,
190
+ "step": 1350
191
+ },
192
+ {
193
+ "epoch": 22.22,
194
+ "learning_rate": 4.2250000000000004e-05,
195
+ "loss": 1.0325,
196
+ "step": 1400
197
+ },
198
+ {
199
+ "epoch": 23.02,
200
+ "learning_rate": 4.181896551724138e-05,
201
+ "loss": 1.0298,
202
+ "step": 1450
203
+ },
204
+ {
205
+ "epoch": 23.8,
206
+ "learning_rate": 4.138793103448276e-05,
207
+ "loss": 0.9984,
208
+ "step": 1500
209
+ },
210
+ {
211
+ "epoch": 23.8,
212
+ "eval_cer": 0.052054180568696776,
213
+ "eval_loss": 0.20215292274951935,
214
+ "eval_runtime": 259.0562,
215
+ "eval_samples_per_second": 16.722,
216
+ "eval_steps_per_second": 2.092,
217
+ "eval_wer": 0.26762703815102107,
218
+ "step": 1500
219
+ },
220
+ {
221
+ "epoch": 24.6,
222
+ "learning_rate": 4.0956896551724136e-05,
223
+ "loss": 1.0118,
224
+ "step": 1550
225
+ },
226
+ {
227
+ "epoch": 25.39,
228
+ "learning_rate": 4.053448275862069e-05,
229
+ "loss": 1.0165,
230
+ "step": 1600
231
+ },
232
+ {
233
+ "epoch": 26.19,
234
+ "learning_rate": 4.0103448275862074e-05,
235
+ "loss": 1.0075,
236
+ "step": 1650
237
+ },
238
+ {
239
+ "epoch": 26.97,
240
+ "learning_rate": 3.967241379310345e-05,
241
+ "loss": 1.003,
242
+ "step": 1700
243
+ },
244
+ {
245
+ "epoch": 27.77,
246
+ "learning_rate": 3.924137931034483e-05,
247
+ "loss": 0.9905,
248
+ "step": 1750
249
+ },
250
+ {
251
+ "epoch": 28.56,
252
+ "learning_rate": 3.8810344827586206e-05,
253
+ "loss": 1.0019,
254
+ "step": 1800
255
+ },
256
+ {
257
+ "epoch": 29.36,
258
+ "learning_rate": 3.837931034482759e-05,
259
+ "loss": 1.0085,
260
+ "step": 1850
261
+ },
262
+ {
263
+ "epoch": 30.16,
264
+ "learning_rate": 3.794827586206896e-05,
265
+ "loss": 0.9868,
266
+ "step": 1900
267
+ },
268
+ {
269
+ "epoch": 30.94,
270
+ "learning_rate": 3.7517241379310345e-05,
271
+ "loss": 0.9816,
272
+ "step": 1950
273
+ },
274
+ {
275
+ "epoch": 31.74,
276
+ "learning_rate": 3.708620689655173e-05,
277
+ "loss": 0.975,
278
+ "step": 2000
279
+ },
280
+ {
281
+ "epoch": 31.74,
282
+ "eval_cer": 0.04868781435187491,
283
+ "eval_loss": 0.19483695924282074,
284
+ "eval_runtime": 259.0795,
285
+ "eval_samples_per_second": 16.721,
286
+ "eval_steps_per_second": 2.092,
287
+ "eval_wer": 0.24688934620864333,
288
+ "step": 2000
289
+ },
290
+ {
291
+ "epoch": 32.53,
292
+ "learning_rate": 3.66551724137931e-05,
293
+ "loss": 0.9552,
294
+ "step": 2050
295
+ },
296
+ {
297
+ "epoch": 33.33,
298
+ "learning_rate": 3.6224137931034484e-05,
299
+ "loss": 0.9649,
300
+ "step": 2100
301
+ },
302
+ {
303
+ "epoch": 34.13,
304
+ "learning_rate": 3.5793103448275866e-05,
305
+ "loss": 0.9632,
306
+ "step": 2150
307
+ },
308
+ {
309
+ "epoch": 34.91,
310
+ "learning_rate": 3.536206896551724e-05,
311
+ "loss": 0.9542,
312
+ "step": 2200
313
+ },
314
+ {
315
+ "epoch": 35.71,
316
+ "learning_rate": 3.493103448275862e-05,
317
+ "loss": 0.9686,
318
+ "step": 2250
319
+ },
320
+ {
321
+ "epoch": 36.5,
322
+ "learning_rate": 3.45e-05,
323
+ "loss": 0.9418,
324
+ "step": 2300
325
+ },
326
+ {
327
+ "epoch": 37.3,
328
+ "learning_rate": 3.406896551724138e-05,
329
+ "loss": 0.9295,
330
+ "step": 2350
331
+ },
332
+ {
333
+ "epoch": 38.09,
334
+ "learning_rate": 3.363793103448276e-05,
335
+ "loss": 0.9372,
336
+ "step": 2400
337
+ },
338
+ {
339
+ "epoch": 38.88,
340
+ "learning_rate": 3.320689655172414e-05,
341
+ "loss": 0.9205,
342
+ "step": 2450
343
+ },
344
+ {
345
+ "epoch": 39.67,
346
+ "learning_rate": 3.277586206896552e-05,
347
+ "loss": 0.9306,
348
+ "step": 2500
349
+ },
350
+ {
351
+ "epoch": 39.67,
352
+ "eval_cer": 0.046377172451571136,
353
+ "eval_loss": 0.19161736965179443,
354
+ "eval_runtime": 258.3157,
355
+ "eval_samples_per_second": 16.77,
356
+ "eval_steps_per_second": 2.098,
357
+ "eval_wer": 0.2377394332752889,
358
+ "step": 2500
359
+ },
360
+ {
361
+ "epoch": 40.47,
362
+ "learning_rate": 3.23448275862069e-05,
363
+ "loss": 0.9331,
364
+ "step": 2550
365
+ },
366
+ {
367
+ "epoch": 41.27,
368
+ "learning_rate": 3.1913793103448276e-05,
369
+ "loss": 0.8936,
370
+ "step": 2600
371
+ },
372
+ {
373
+ "epoch": 42.06,
374
+ "learning_rate": 3.148275862068966e-05,
375
+ "loss": 0.8987,
376
+ "step": 2650
377
+ },
378
+ {
379
+ "epoch": 42.85,
380
+ "learning_rate": 3.105172413793104e-05,
381
+ "loss": 0.8853,
382
+ "step": 2700
383
+ },
384
+ {
385
+ "epoch": 43.64,
386
+ "learning_rate": 3.0620689655172415e-05,
387
+ "loss": 0.9106,
388
+ "step": 2750
389
+ },
390
+ {
391
+ "epoch": 44.44,
392
+ "learning_rate": 3.0189655172413794e-05,
393
+ "loss": 0.8932,
394
+ "step": 2800
395
+ },
396
+ {
397
+ "epoch": 45.24,
398
+ "learning_rate": 2.9758620689655176e-05,
399
+ "loss": 0.9096,
400
+ "step": 2850
401
+ },
402
+ {
403
+ "epoch": 46.03,
404
+ "learning_rate": 2.932758620689655e-05,
405
+ "loss": 0.8919,
406
+ "step": 2900
407
+ },
408
+ {
409
+ "epoch": 46.82,
410
+ "learning_rate": 2.8896551724137933e-05,
411
+ "loss": 0.8744,
412
+ "step": 2950
413
+ },
414
+ {
415
+ "epoch": 47.61,
416
+ "learning_rate": 2.8465517241379315e-05,
417
+ "loss": 0.8868,
418
+ "step": 3000
419
+ },
420
+ {
421
+ "epoch": 47.61,
422
+ "eval_cer": 0.04391713560081669,
423
+ "eval_loss": 0.19031885266304016,
424
+ "eval_runtime": 265.4438,
425
+ "eval_samples_per_second": 16.32,
426
+ "eval_steps_per_second": 2.042,
427
+ "eval_wer": 0.2257400664872566,
428
+ "step": 3000
429
+ },
430
+ {
431
+ "epoch": 48.41,
432
+ "learning_rate": 2.803448275862069e-05,
433
+ "loss": 0.8793,
434
+ "step": 3050
435
+ },
436
+ {
437
+ "epoch": 49.2,
438
+ "learning_rate": 2.7603448275862072e-05,
439
+ "loss": 0.8739,
440
+ "step": 3100
441
+ },
442
+ {
443
+ "epoch": 49.99,
444
+ "learning_rate": 2.717241379310345e-05,
445
+ "loss": 0.8696,
446
+ "step": 3150
447
+ },
448
+ {
449
+ "epoch": 50.78,
450
+ "learning_rate": 2.674137931034483e-05,
451
+ "loss": 0.863,
452
+ "step": 3200
453
+ },
454
+ {
455
+ "epoch": 51.58,
456
+ "learning_rate": 2.6310344827586207e-05,
457
+ "loss": 0.8612,
458
+ "step": 3250
459
+ },
460
+ {
461
+ "epoch": 52.38,
462
+ "learning_rate": 2.587931034482759e-05,
463
+ "loss": 0.8639,
464
+ "step": 3300
465
+ },
466
+ {
467
+ "epoch": 53.17,
468
+ "learning_rate": 2.5448275862068964e-05,
469
+ "loss": 0.8523,
470
+ "step": 3350
471
+ },
472
+ {
473
+ "epoch": 53.96,
474
+ "learning_rate": 2.5017241379310346e-05,
475
+ "loss": 0.8577,
476
+ "step": 3400
477
+ },
478
+ {
479
+ "epoch": 54.75,
480
+ "learning_rate": 2.4586206896551725e-05,
481
+ "loss": 0.8465,
482
+ "step": 3450
483
+ },
484
+ {
485
+ "epoch": 55.55,
486
+ "learning_rate": 2.4155172413793103e-05,
487
+ "loss": 0.8424,
488
+ "step": 3500
489
+ },
490
+ {
491
+ "epoch": 55.55,
492
+ "eval_cer": 0.042293710472586024,
493
+ "eval_loss": 0.17861121892929077,
494
+ "eval_runtime": 259.0288,
495
+ "eval_samples_per_second": 16.724,
496
+ "eval_steps_per_second": 2.092,
497
+ "eval_wer": 0.22061104954883648,
498
+ "step": 3500
499
+ },
500
+ {
501
+ "epoch": 56.35,
502
+ "learning_rate": 2.3724137931034485e-05,
503
+ "loss": 0.8436,
504
+ "step": 3550
505
+ },
506
+ {
507
+ "epoch": 57.14,
508
+ "learning_rate": 2.3293103448275864e-05,
509
+ "loss": 0.8404,
510
+ "step": 3600
511
+ },
512
+ {
513
+ "epoch": 57.93,
514
+ "learning_rate": 2.2862068965517242e-05,
515
+ "loss": 0.8304,
516
+ "step": 3650
517
+ },
518
+ {
519
+ "epoch": 58.72,
520
+ "learning_rate": 2.2431034482758624e-05,
521
+ "loss": 0.8331,
522
+ "step": 3700
523
+ },
524
+ {
525
+ "epoch": 59.52,
526
+ "learning_rate": 2.2000000000000003e-05,
527
+ "loss": 0.824,
528
+ "step": 3750
529
+ },
530
+ {
531
+ "epoch": 60.31,
532
+ "learning_rate": 2.1568965517241378e-05,
533
+ "loss": 0.8328,
534
+ "step": 3800
535
+ },
536
+ {
537
+ "epoch": 61.11,
538
+ "learning_rate": 2.113793103448276e-05,
539
+ "loss": 0.8234,
540
+ "step": 3850
541
+ },
542
+ {
543
+ "epoch": 61.89,
544
+ "learning_rate": 2.070689655172414e-05,
545
+ "loss": 0.8098,
546
+ "step": 3900
547
+ },
548
+ {
549
+ "epoch": 62.69,
550
+ "learning_rate": 2.0275862068965517e-05,
551
+ "loss": 0.8287,
552
+ "step": 3950
553
+ },
554
+ {
555
+ "epoch": 63.49,
556
+ "learning_rate": 1.98448275862069e-05,
557
+ "loss": 0.8126,
558
+ "step": 4000
559
+ },
560
+ {
561
+ "epoch": 63.49,
562
+ "eval_cer": 0.04164135252228475,
563
+ "eval_loss": 0.18486249446868896,
564
+ "eval_runtime": 261.7127,
565
+ "eval_samples_per_second": 16.553,
566
+ "eval_steps_per_second": 2.071,
567
+ "eval_wer": 0.2159886021845813,
568
+ "step": 4000
569
+ },
570
+ {
571
+ "epoch": 64.28,
572
+ "learning_rate": 1.9413793103448277e-05,
573
+ "loss": 0.8089,
574
+ "step": 4050
575
+ },
576
+ {
577
+ "epoch": 65.08,
578
+ "learning_rate": 1.8982758620689656e-05,
579
+ "loss": 0.8126,
580
+ "step": 4100
581
+ },
582
+ {
583
+ "epoch": 65.86,
584
+ "learning_rate": 1.8551724137931034e-05,
585
+ "loss": 0.7975,
586
+ "step": 4150
587
+ },
588
+ {
589
+ "epoch": 66.66,
590
+ "learning_rate": 1.8120689655172416e-05,
591
+ "loss": 0.8049,
592
+ "step": 4200
593
+ },
594
+ {
595
+ "epoch": 67.45,
596
+ "learning_rate": 1.7698275862068966e-05,
597
+ "loss": 0.8088,
598
+ "step": 4250
599
+ },
600
+ {
601
+ "epoch": 68.25,
602
+ "learning_rate": 1.7267241379310344e-05,
603
+ "loss": 0.8038,
604
+ "step": 4300
605
+ },
606
+ {
607
+ "epoch": 69.05,
608
+ "learning_rate": 1.6836206896551726e-05,
609
+ "loss": 0.7886,
610
+ "step": 4350
611
+ },
612
+ {
613
+ "epoch": 69.83,
614
+ "learning_rate": 1.6405172413793105e-05,
615
+ "loss": 0.7735,
616
+ "step": 4400
617
+ },
618
+ {
619
+ "epoch": 70.63,
620
+ "learning_rate": 1.5974137931034483e-05,
621
+ "loss": 0.7837,
622
+ "step": 4450
623
+ },
624
+ {
625
+ "epoch": 71.42,
626
+ "learning_rate": 1.5543103448275865e-05,
627
+ "loss": 0.7901,
628
+ "step": 4500
629
+ },
630
+ {
631
+ "epoch": 71.42,
632
+ "eval_cer": 0.04126786514615806,
633
+ "eval_loss": 0.18691900372505188,
634
+ "eval_runtime": 261.5464,
635
+ "eval_samples_per_second": 16.563,
636
+ "eval_steps_per_second": 2.072,
637
+ "eval_wer": 0.21383568149438023,
638
+ "step": 4500
639
+ },
640
+ {
641
+ "epoch": 72.22,
642
+ "learning_rate": 1.5112068965517242e-05,
643
+ "loss": 0.7949,
644
+ "step": 4550
645
+ },
646
+ {
647
+ "epoch": 73.02,
648
+ "learning_rate": 1.468103448275862e-05,
649
+ "loss": 0.7893,
650
+ "step": 4600
651
+ },
652
+ {
653
+ "epoch": 73.8,
654
+ "learning_rate": 1.4249999999999999e-05,
655
+ "loss": 0.7603,
656
+ "step": 4650
657
+ },
658
+ {
659
+ "epoch": 74.6,
660
+ "learning_rate": 1.3818965517241381e-05,
661
+ "loss": 0.776,
662
+ "step": 4700
663
+ },
664
+ {
665
+ "epoch": 75.39,
666
+ "learning_rate": 1.338793103448276e-05,
667
+ "loss": 0.7755,
668
+ "step": 4750
669
+ },
670
+ {
671
+ "epoch": 76.19,
672
+ "learning_rate": 1.2956896551724138e-05,
673
+ "loss": 0.7751,
674
+ "step": 4800
675
+ },
676
+ {
677
+ "epoch": 76.97,
678
+ "learning_rate": 1.2525862068965518e-05,
679
+ "loss": 0.7608,
680
+ "step": 4850
681
+ },
682
+ {
683
+ "epoch": 77.77,
684
+ "learning_rate": 1.2094827586206897e-05,
685
+ "loss": 0.7663,
686
+ "step": 4900
687
+ },
688
+ {
689
+ "epoch": 78.56,
690
+ "learning_rate": 1.1663793103448277e-05,
691
+ "loss": 0.7656,
692
+ "step": 4950
693
+ },
694
+ {
695
+ "epoch": 79.36,
696
+ "learning_rate": 1.1232758620689656e-05,
697
+ "loss": 0.7671,
698
+ "step": 5000
699
+ },
700
+ {
701
+ "epoch": 79.36,
702
+ "eval_cer": 0.03937054927543449,
703
+ "eval_loss": 0.18550464510917664,
704
+ "eval_runtime": 260.3539,
705
+ "eval_samples_per_second": 16.639,
706
+ "eval_steps_per_second": 2.082,
707
+ "eval_wer": 0.20747190121893302,
708
+ "step": 5000
709
+ },
710
+ {
711
+ "epoch": 80.16,
712
+ "learning_rate": 1.0801724137931036e-05,
713
+ "loss": 0.7694,
714
+ "step": 5050
715
+ },
716
+ {
717
+ "epoch": 80.94,
718
+ "learning_rate": 1.0370689655172414e-05,
719
+ "loss": 0.7672,
720
+ "step": 5100
721
+ },
722
+ {
723
+ "epoch": 81.74,
724
+ "learning_rate": 9.939655172413793e-06,
725
+ "loss": 0.7444,
726
+ "step": 5150
727
+ },
728
+ {
729
+ "epoch": 82.53,
730
+ "learning_rate": 9.508620689655173e-06,
731
+ "loss": 0.7534,
732
+ "step": 5200
733
+ },
734
+ {
735
+ "epoch": 83.33,
736
+ "learning_rate": 9.077586206896552e-06,
737
+ "loss": 0.7453,
738
+ "step": 5250
739
+ },
740
+ {
741
+ "epoch": 84.13,
742
+ "learning_rate": 8.646551724137932e-06,
743
+ "loss": 0.7494,
744
+ "step": 5300
745
+ },
746
+ {
747
+ "epoch": 84.91,
748
+ "learning_rate": 8.224137931034483e-06,
749
+ "loss": 0.7425,
750
+ "step": 5350
751
+ },
752
+ {
753
+ "epoch": 85.71,
754
+ "learning_rate": 7.793103448275863e-06,
755
+ "loss": 0.7499,
756
+ "step": 5400
757
+ },
758
+ {
759
+ "epoch": 86.5,
760
+ "learning_rate": 7.370689655172413e-06,
761
+ "loss": 0.735,
762
+ "step": 5450
763
+ },
764
+ {
765
+ "epoch": 87.3,
766
+ "learning_rate": 6.939655172413794e-06,
767
+ "loss": 0.7467,
768
+ "step": 5500
769
+ },
770
+ {
771
+ "epoch": 87.3,
772
+ "eval_cer": 0.03894228375080922,
773
+ "eval_loss": 0.18841499090194702,
774
+ "eval_runtime": 261.43,
775
+ "eval_samples_per_second": 16.57,
776
+ "eval_steps_per_second": 2.073,
777
+ "eval_wer": 0.20490739274972297,
778
+ "step": 5500
779
+ },
780
+ {
781
+ "epoch": 88.09,
782
+ "learning_rate": 6.508620689655173e-06,
783
+ "loss": 0.7348,
784
+ "step": 5550
785
+ },
786
+ {
787
+ "epoch": 88.88,
788
+ "learning_rate": 6.0775862068965515e-06,
789
+ "loss": 0.7244,
790
+ "step": 5600
791
+ },
792
+ {
793
+ "epoch": 89.67,
794
+ "learning_rate": 5.646551724137932e-06,
795
+ "loss": 0.7394,
796
+ "step": 5650
797
+ },
798
+ {
799
+ "epoch": 90.47,
800
+ "learning_rate": 5.21551724137931e-06,
801
+ "loss": 0.7423,
802
+ "step": 5700
803
+ },
804
+ {
805
+ "epoch": 91.27,
806
+ "learning_rate": 4.78448275862069e-06,
807
+ "loss": 0.7251,
808
+ "step": 5750
809
+ },
810
+ {
811
+ "epoch": 92.06,
812
+ "learning_rate": 4.353448275862069e-06,
813
+ "loss": 0.7304,
814
+ "step": 5800
815
+ },
816
+ {
817
+ "epoch": 92.85,
818
+ "learning_rate": 3.9224137931034484e-06,
819
+ "loss": 0.7153,
820
+ "step": 5850
821
+ },
822
+ {
823
+ "epoch": 93.64,
824
+ "learning_rate": 3.491379310344828e-06,
825
+ "loss": 0.7287,
826
+ "step": 5900
827
+ },
828
+ {
829
+ "epoch": 94.44,
830
+ "learning_rate": 3.0603448275862068e-06,
831
+ "loss": 0.7349,
832
+ "step": 5950
833
+ },
834
+ {
835
+ "epoch": 95.24,
836
+ "learning_rate": 2.6293103448275866e-06,
837
+ "loss": 0.731,
838
+ "step": 6000
839
+ },
840
+ {
841
+ "epoch": 95.24,
842
+ "eval_cer": 0.03871819132513321,
843
+ "eval_loss": 0.1877404898405075,
844
+ "eval_runtime": 259.3367,
845
+ "eval_samples_per_second": 16.704,
846
+ "eval_steps_per_second": 2.09,
847
+ "eval_wer": 0.2059838530948235,
848
+ "step": 6000
849
+ },
850
+ {
851
+ "epoch": 96.03,
852
+ "learning_rate": 2.1982758620689655e-06,
853
+ "loss": 0.7151,
854
+ "step": 6050
855
+ },
856
+ {
857
+ "epoch": 96.82,
858
+ "learning_rate": 1.7672413793103449e-06,
859
+ "loss": 0.713,
860
+ "step": 6100
861
+ },
862
+ {
863
+ "epoch": 97.61,
864
+ "learning_rate": 1.3362068965517243e-06,
865
+ "loss": 0.7257,
866
+ "step": 6150
867
+ },
868
+ {
869
+ "epoch": 98.41,
870
+ "learning_rate": 9.051724137931035e-07,
871
+ "loss": 0.7287,
872
+ "step": 6200
873
+ },
874
+ {
875
+ "epoch": 99.2,
876
+ "learning_rate": 4.7413793103448276e-07,
877
+ "loss": 0.7273,
878
+ "step": 6250
879
+ },
880
+ {
881
+ "epoch": 99.99,
882
+ "learning_rate": 4.310344827586207e-08,
883
+ "loss": 0.7082,
884
+ "step": 6300
885
+ },
886
+ {
887
+ "epoch": 99.99,
888
+ "step": 6300,
889
+ "total_flos": 4.0887685530877926e+20,
890
+ "train_loss": 1.049089940994505,
891
+ "train_runtime": 95054.1856,
892
+ "train_samples_per_second": 10.723,
893
+ "train_steps_per_second": 0.066
894
+ }
895
+ ],
896
+ "max_steps": 6300,
897
+ "num_train_epochs": 100,
898
+ "total_flos": 4.0887685530877926e+20,
899
+ "trial_name": null,
900
+ "trial_params": null
901
+ }
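
The log above is the tail of the Trainer's state file (assuming the standard `trainer_state.json` layout written by the Hugging Face Trainer, with checkpoints stored under `log_history`). A minimal sketch, not part of this commit, for pulling the evaluation checkpoints out of that file to see how WER and CER improve over training:

```python
import json

# Minimal sketch: assumes the standard trainer_state.json layout produced by
# the Hugging Face Trainer, where each evaluation checkpoint in "log_history"
# carries "eval_wer" and "eval_cer" alongside "step".
with open("trainer_state.json") as f:
    state = json.load(f)

eval_points = [
    (entry["step"], entry["eval_wer"], entry["eval_cer"])
    for entry in state["log_history"]
    if "eval_wer" in entry
]

for step, wer, cer in eval_points:
    print(f"step {step:>5}: WER={wer:.4f}  CER={cer:.4f}")
```
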
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d28ad45d3721baa2d9d1ae37754a308d7348290cd7784d5d9c00aa9c1ffe0061
3
+ size 3055
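
`training_args.bin` is stored as a Git LFS pointer (the three lines above are the pointer, not the binary itself), so it has to be fetched via LFS before it can be read. Assuming it is the serialized `TrainingArguments` object that the Trainer saves with `torch.save`, it can be inspected roughly like this (a sketch, not part of this commit; recent PyTorch versions may additionally require `weights_only=False`):

```python
import torch

# Sketch: training_args.bin is assumed to be the pickled TrainingArguments
# object saved by the Hugging Face Trainer; transformers must be installed
# for unpickling to succeed.
args = torch.load("training_args.bin", map_location="cpu")
print(args.learning_rate, args.num_train_epochs, args.per_device_train_batch_size)
```
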
vocab.json ADDED
@@ -0,0 +1 @@
 
1
+ {"a": 1, "c": 2, "e": 3, "i": 4, "j": 5, "k": 6, "l": 7, "m": 8, "n": 9, "o": 10, "p": 11, "u": 12, "x": 13, "y": 14, "́": 15, "а": 16, "б": 17, "в": 18, "г": 19, "д": 20, "е": 21, "ж": 22, "з": 23, "и": 24, "й": 25, "к": 26, "л": 27, "м": 28, "н": 29, "о": 30, "п": 31, "р": 32, "с": 33, "т": 34, "у": 35, "ф": 36, "х": 37, "ц": 38, "ч": 39, "ш": 40, "щ": 41, "ь": 42, "ю": 43, "я": 44, "є": 45, "і": 46, "ї": 47, "ґ": 48, "|": 0, "[UNK]": 49, "[PAD]": 50}