nguyenvulebinh committed
Commit 778e524
1 Parent(s): 1155872

add init code train

.gitignore ADDED
@@ -0,0 +1,258 @@
+
+ # Created by https://www.toptal.com/developers/gitignore/api/pycharm,python
+ # Edit at https://www.toptal.com/developers/gitignore?templates=pycharm,python
+
+ ### PyCharm ###
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+ # User-specific stuff
+ .idea/**/workspace.xml
+ .idea/**/tasks.xml
+ .idea/**/usage.statistics.xml
+ .idea/**/dictionaries
+ .idea/**/shelf
+
+ # AWS User-specific
+ .idea/**/aws.xml
+
+ # Generated files
+ .idea/**/contentModel.xml
+
+ # Sensitive or high-churn files
+ .idea/**/dataSources/
+ .idea/**/dataSources.ids
+ .idea/**/dataSources.local.xml
+ .idea/**/sqlDataSources.xml
+ .idea/**/dynamic.xml
+ .idea/**/uiDesigner.xml
+ .idea/**/dbnavigator.xml
+
+ # Gradle
+ .idea/**/gradle.xml
+ .idea/**/libraries
+
+ # Gradle and Maven with auto-import
+ # When using Gradle or Maven with auto-import, you should exclude module files,
+ # since they will be recreated, and may cause churn. Uncomment if using
+ # auto-import.
+ # .idea/artifacts
+ # .idea/compiler.xml
+ # .idea/jarRepositories.xml
+ # .idea/modules.xml
+ # .idea/*.iml
+ # .idea/modules
+ # *.iml
+ # *.ipr
+
+ # CMake
+ cmake-build-*/
+
+ # Mongo Explorer plugin
+ .idea/**/mongoSettings.xml
+
+ # File-based project format
+ *.iws
+
+ # IntelliJ
+ out/
+
+ # mpeltonen/sbt-idea plugin
+ .idea_modules/
+
+ # JIRA plugin
+ atlassian-ide-plugin.xml
+
+ # Cursive Clojure plugin
+ .idea/replstate.xml
+
+ # Crashlytics plugin (for Android Studio and IntelliJ)
+ com_crashlytics_export_strings.xml
+ crashlytics.properties
+ crashlytics-build.properties
+ fabric.properties
+
+ # Editor-based Rest Client
+ .idea/httpRequests
+
+ # Android studio 3.1+ serialized cache file
+ .idea/caches/build_file_checksums.ser
+
+ ### PyCharm Patch ###
+ # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
+
+ # *.iml
+ # modules.xml
+ # .idea/misc.xml
+ # *.ipr
+
+ # Sonarlint plugin
+ # https://plugins.jetbrains.com/plugin/7973-sonarlint
+ .idea/**/sonarlint/
+
+ # SonarQube Plugin
+ # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
+ .idea/**/sonarIssues.xml
+
+ # Markdown Navigator plugin
+ # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
+ .idea/**/markdown-navigator.xml
+ .idea/**/markdown-navigator-enh.xml
+ .idea/**/markdown-navigator/
+
+ # Cache file creation bug
+ # See https://youtrack.jetbrains.com/issue/JBR-2257
+ .idea/$CACHE_FILE$
+
+ # CodeStream plugin
+ # https://plugins.jetbrains.com/plugin/12206-codestream
+ .idea/codestream.xml
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # End of https://www.toptal.com/developers/gitignore/api/pycharm,python
+
+
+ data-bin/
+
+ .DS_Store
+
+ .idea/
data_handler.py ADDED
@@ -0,0 +1,68 @@
+ import torch
+
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Union
+ from transformers import Wav2Vec2Processor
+
+
+ @dataclass
+ class DataCollatorCTCWithPadding:
+     """
+     Data collator that will dynamically pad the inputs received.
+     Args:
+         processor (:class:`~transformers.Wav2Vec2Processor`)
+             The processor used for processing the data.
+         padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+             Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+             among:
+             * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+               sequence is provided).
+             * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+               maximum acceptable input length for the model if that argument is not provided.
+             * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
+               different lengths).
+         max_length (:obj:`int`, `optional`):
+             Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
+         max_length_labels (:obj:`int`, `optional`):
+             Maximum length of the ``labels`` returned list and optionally padding length (see above).
+         pad_to_multiple_of (:obj:`int`, `optional`):
+             If set will pad the sequence to a multiple of the provided value.
+             This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+             7.5 (Volta).
+     """
+
+     processor: Wav2Vec2Processor
+     padding: Union[bool, str] = True
+     max_length: Optional[int] = None
+     max_length_labels: Optional[int] = None
+     pad_to_multiple_of: Optional[int] = None
+     pad_to_multiple_of_labels: Optional[int] = None
+
+     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+         # split inputs and labels since they have to be of different lengths and need
+         # different padding methods
+         input_features = [{"input_values": feature["input_values"]} for feature in features]
+         label_features = [{"input_ids": feature["labels"]} for feature in features]
+
+         batch = self.processor.pad(
+             input_features,
+             padding=self.padding,
+             max_length=self.max_length,
+             pad_to_multiple_of=self.pad_to_multiple_of,
+             return_tensors="pt",
+         )
+         with self.processor.as_target_processor():
+             labels_batch = self.processor.pad(
+                 label_features,
+                 padding=self.padding,
+                 max_length=self.max_length_labels,
+                 pad_to_multiple_of=self.pad_to_multiple_of_labels,
+                 return_tensors="pt",
+             )
+
+         # replace padding with -100 so the padded positions are ignored by the CTC loss
+         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
+
+         batch["labels"] = labels
+
+         return batch
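
As a quick illustration of how this collator behaves outside the Trainer, here is a minimal sketch (not part of the commit): it builds a throwaway processor from a tiny made-up vocabulary and pads two fake examples. The temp vocab, its contents, and the audio values are illustrative assumptions only.

import json, os, tempfile
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
from data_handler import DataCollatorCTCWithPadding

# throwaway vocab just for the sketch
tmp_dir = tempfile.mkdtemp()
vocab_path = os.path.join(tmp_dir, "vocab.json")
with open(vocab_path, "w", encoding="utf-8") as f:
    json.dump({"<pad>": 0, "<unk>": 1, "|": 2, "a": 3, "b": 4}, f)

tokenizer = Wav2Vec2CTCTokenizer(vocab_path, unk_token="<unk>", pad_token="<pad>", word_delimiter_token="|")
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0,
                                             do_normalize=True, return_attention_mask=False)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
features = [
    {"input_values": [0.1] * 400, "labels": [3, 2, 4]},        # short fake utterance
    {"input_values": [0.2] * 800, "labels": [4, 4, 2, 3, 3]},  # longer fake utterance
]
batch = collator(features)
print(batch["input_values"].shape)  # torch.Size([2, 800]) -> padded to the longest input
print(batch["labels"])              # padded label positions are set to -100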
main.py ADDED
@@ -0,0 +1,180 @@
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
+ from datasets import load_from_disk
+ from data_handler import DataCollatorCTCWithPadding
+ from transformers import TrainingArguments
+ from transformers import Trainer, logging
+ from metric_utils import compute_metrics_fn
+ from transformers.trainer_utils import get_last_checkpoint
+ import json, random
+ import os, glob
+
+ logging.set_verbosity_info()
+
+
+ def load_pretrained_model(checkpoint_path=None):
+     if checkpoint_path is None:
+         pre_trained_path = './model-bin/pretrained/base'
+         tokenizer = Wav2Vec2CTCTokenizer("./model-bin/finetune/vocab.json",
+                                          unk_token="<unk>",
+                                          pad_token="<pad>",
+                                          word_delimiter_token="|")
+
+         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pre_trained_path)
+         processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+         model = Wav2Vec2ForCTC.from_pretrained(
+             pre_trained_path,
+             gradient_checkpointing=True,
+             ctc_loss_reduction="mean",
+             pad_token_id=processor.tokenizer.pad_token_id,
+         )
+         model.freeze_feature_extractor()
+     else:
+         tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(checkpoint_path)
+
+         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint_path)
+         processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+         model = Wav2Vec2ForCTC.from_pretrained(
+             checkpoint_path,
+             gradient_checkpointing=True,
+             ctc_loss_reduction="mean",
+             pad_token_id=processor.tokenizer.pad_token_id,
+         )
+         # model.freeze_feature_extractor()
+
+     model_total_params = sum(p.numel() for p in model.parameters())
+     model_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print(model)
+     print("model_total_params: {}\nmodel_total_params_trainable: {}".format(model_total_params,
+                                                                             model_total_params_trainable))
+     return model, processor
+
+
+ def prepare_dataset(batch, processor):
+     # check that all files have the correct sampling rate
+     assert (
+         len(set(batch["sampling_rate"])) == 1
+     ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
+
+     batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
+
+     with processor.as_target_processor():
+         batch["labels"] = processor(batch["target_text"]).input_ids
+     return batch
+
+
+ def load_prepared_dataset(path, processor, cache_file_name):
+     dataset = load_from_disk(path)
+     processed_dataset = dataset.map(prepare_dataset,
+                                     remove_columns=dataset.column_names,
+                                     batch_size=8,
+                                     num_proc=8,
+                                     batched=True,
+                                     fn_kwargs={"processor": processor},
+                                     cache_file_name=cache_file_name)
+     return processed_dataset
+
+
+ # def get_train_dataset():
+ #     for i in range()
+
+ if __name__ == "__main__":
+
+     checkpoint_path = "./model-bin/finetune/base/"
+     train_dataset_root_folder = './data-bin/train_dataset'
+     test_dataset_root_folder = './data-bin/test_dataset'
+     cache_processing_dataset_folder = './data-bin/cache/'
+     if not os.path.exists(cache_processing_dataset_folder):
+         os.makedirs(cache_processing_dataset_folder)
+     num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
+     num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
+     num_epochs = 20
+
+     training_args = TrainingArguments(
+         output_dir=checkpoint_path,
+         # fp16=True,
+         group_by_length=True,
+         per_device_train_batch_size=2,
+         per_device_eval_batch_size=2,
+         gradient_accumulation_steps=1,
+         num_train_epochs=1,  # one Trainer epoch per data shard
+         logging_steps=1,
+         learning_rate=1e-4,
+         weight_decay=0.005,
+         warmup_steps=5000,
+         save_total_limit=2,
+         ignore_data_skip=True,
+         logging_dir=os.path.join(checkpoint_path, 'log'),
+         metric_for_best_model='wer',
+         save_strategy="epoch",
+         evaluation_strategy="epoch",
+         # save_steps=5,
+         # eval_steps=5,
+     )
+
+     # PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
+     last_checkpoint_path = None
+     last_epoch_idx = 0
+     if os.path.exists(checkpoint_path):
+         last_checkpoint_path = get_last_checkpoint(checkpoint_path)
+         if last_checkpoint_path is not None:
+             with open(os.path.join(last_checkpoint_path, "trainer_state.json"), 'r', encoding='utf-8') as file:
+                 trainer_state = json.load(file)
+                 last_epoch_idx = int(trainer_state['epoch'])
+
+     w2v_ctc_model, w2v_ctc_processor = load_pretrained_model()
+     data_collator = DataCollatorCTCWithPadding(processor=w2v_ctc_processor, padding=True)
+
+     for epoch_idx in range(last_epoch_idx, num_epochs):
+         # loop over training shards
+         train_dataset_shard_idx = epoch_idx % num_train_shards
+         # pick the test shard that corresponds to the current train shard
+         test_dataset_shard_idx = round(train_dataset_shard_idx / (num_train_shards / num_test_shards))
+         num_test_sub_shard = 1000  # split the test shard into subsets; default is 8
+         idx_sub_shard = train_dataset_shard_idx % num_test_sub_shard  # loop over the test shard subsets
+
+         # load the train shard
+         train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
+                                                            'shard_{}'.format(train_dataset_shard_idx)),
+                                               w2v_ctc_processor,
+                                               cache_file_name=os.path.join(cache_processing_dataset_folder,
+                                                                            'cache-train-shard-{}.arrow'.format(
+                                                                                train_dataset_shard_idx))
+                                               ).shard(1000, 0)  # remove this .shard() split for a full training run
+         # load the test shard subset
+         test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
+                                                           'shard_{}'.format(test_dataset_shard_idx)),
+                                              w2v_ctc_processor,
+                                              cache_file_name=os.path.join(cache_processing_dataset_folder,
+                                                                           'cache-test-shard-{}.arrow'.format(
+                                                                               test_dataset_shard_idx))
+                                              ).shard(num_test_sub_shard, idx_sub_shard)
+
+         # Init trainer
+         trainer = Trainer(
+             model=w2v_ctc_model,
+             data_collator=data_collator,
+             args=training_args,
+             compute_metrics=compute_metrics_fn(w2v_ctc_processor),
+             train_dataset=train_dataset,
+             eval_dataset=test_dataset,
+             tokenizer=w2v_ctc_processor.feature_extractor
+         )
+         # manually bump num_train_epochs because each outer epoch trains on a single shard
+         training_args.num_train_epochs = epoch_idx + 1
+
+         logging.get_logger().info('Train shard idx: {}'.format(train_dataset_shard_idx))
+         logging.get_logger().info('Valid shard idx: {} sub_shard: {}'.format(test_dataset_shard_idx, idx_sub_shard))
+
+         if last_checkpoint_path is not None:
+             # resume training from the latest checkpoint if one exists
+             trainer.train(resume_from_checkpoint=True)
+         else:
+             # train from the pre-trained wav2vec2 checkpoint
+             trainer.train()
+             last_checkpoint_path = get_last_checkpoint(checkpoint_path)
+
+         # clear cache files to free disk space
+         # test_dataset.cleanup_cache_files()
+         # train_dataset.cleanup_cache_files()
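
main.py expects each shard under ./data-bin/train_dataset and ./data-bin/test_dataset to be a datasets.Dataset readable with load_from_disk (so presumably written with save_to_disk) and exposing the columns that prepare_dataset consumes: speech, sampling_rate and target_text. A minimal sketch of writing such a shard follows; the audio arrays and transcripts are placeholders, not real data.

import os
from datasets import Dataset

# two fake 1-second utterances at 16 kHz; real shards would contain audio
# decoded e.g. with soundfile, plus the matching transcripts
examples = {
    "speech": [[0.0] * 16000, [0.01] * 16000],
    "sampling_rate": [16000, 16000],
    "target_text": ["xin chào", "việt nam"],
}
os.makedirs("./data-bin/train_dataset", exist_ok=True)
Dataset.from_dict(examples).save_to_disk("./data-bin/train_dataset/shard_0")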
metric_utils.py ADDED
@@ -0,0 +1,25 @@
+ import numpy as np
+ from datasets import load_metric
+
+ wer_metric = load_metric("./model-bin/metrics/wer")
+
+
+ # print(wer_metric)
+
+
+ def compute_metrics_fn(processor):
+     def compute(pred):
+         pred_logits = pred.predictions
+         pred_ids = np.argmax(pred_logits, axis=-1)
+
+         pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
+
+         pred_str = processor.batch_decode(pred_ids)
+         # we do not want to group tokens when computing the metrics
+         label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
+
+         wer = wer_metric.compute(predictions=pred_str, references=label_str)
+
+         return {"wer": wer}
+
+     return compute
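
The closure returned by compute_metrics_fn ultimately feeds decoded strings into wer_metric.compute, so the metric can be sanity-checked on its own (assuming the repo layout above and jiwer installed); the transcripts below are illustrative only.

from metric_utils import wer_metric

score = wer_metric.compute(
    predictions=["xin chào việt nam", "hôm nay trời mưa"],
    references=["xin chào việt nam", "hôm nay trời đẹp"],
)
print(score)  # 0.125: one substituted word out of eight reference words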
model-bin/metrics/wer/wer.py ADDED
@@ -0,0 +1,105 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Datasets Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ Word Error Rate (WER) metric. """
+
+ from jiwer import compute_measures
+
+ import datasets
+
+ _CITATION = """\
+ @inproceedings{inproceedings,
+     author = {Morris, Andrew and Maier, Viktoria and Green, Phil},
+     year = {2004},
+     month = {01},
+     pages = {},
+     title = {From WER and RIL to MER and WIL: improved evaluation measures for connected speech recognition.}
+ }
+ """
+
+ _DESCRIPTION = """\
+ Word error rate (WER) is a common metric of the performance of an automatic speech recognition system.
+
+ The general difficulty of measuring performance lies in the fact that the recognized word sequence can have a different length from the reference word sequence (supposedly the correct one). The WER is derived from the Levenshtein distance, working at the word level instead of the phoneme level. The WER is a valuable tool for comparing different systems as well as for evaluating improvements within one system. This kind of measurement, however, provides no details on the nature of recognition errors and further work is therefore required to identify the main source(s) of error and to focus any research effort.
+
+ This problem is solved by first aligning the recognized word sequence with the reference (spoken) word sequence using dynamic string alignment. Examination of this issue is seen through a theory called the power law that states the correlation between perplexity and word error rate.
+
+ Word error rate can then be computed as:
+
+ WER = (S + D + I) / N = (S + D + I) / (S + D + C)
+
+ where
+
+ S is the number of substitutions,
+ D is the number of deletions,
+ I is the number of insertions,
+ C is the number of correct words,
+ N is the number of words in the reference (N=S+D+C).
+
+ WER's output is a non-negative number (it can exceed 1 when there are many insertions). It indicates the fraction of reference words that were incorrectly predicted. The lower the value, the better the
+ performance of the ASR system, with a WER of 0 being a perfect score.
+ """
+
+ _KWARGS_DESCRIPTION = """
+ Compute WER score of transcribed segments against references.
+
+ Args:
+     references: List of references for each speech input.
+     predictions: List of transcriptions to score.
+     concatenate_texts (bool, default=False): Whether to concatenate all input texts or compute WER iteratively.
+
+ Returns:
+     (float): the word error rate
+
+ Examples:
+
+     >>> predictions = ["this is the prediction", "there is an other sample"]
+     >>> references = ["this is the reference", "there is another one"]
+     >>> wer = datasets.load_metric("wer")
+     >>> wer_score = wer.compute(predictions=predictions, references=references)
+     >>> print(wer_score)
+     0.5
+ """
+
+
+ @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+ class WER(datasets.Metric):
+     def _info(self):
+         return datasets.MetricInfo(
+             description=_DESCRIPTION,
+             citation=_CITATION,
+             inputs_description=_KWARGS_DESCRIPTION,
+             features=datasets.Features(
+                 {
+                     "predictions": datasets.Value("string", id="sequence"),
+                     "references": datasets.Value("string", id="sequence"),
+                 }
+             ),
+             codebase_urls=["https://github.com/jitsi/jiwer/"],
+             reference_urls=[
+                 "https://en.wikipedia.org/wiki/Word_error_rate",
+             ],
+         )
+
+     def _compute(self, predictions=None, references=None, concatenate_texts=False):
+         if concatenate_texts:
+             return compute_measures(references, predictions)["wer"]
+         else:
+             incorrect = 0
+             total = 0
+             for prediction, reference in zip(predictions, references):
+                 measures = compute_measures(reference, prediction)
+                 incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
+                 total += measures["substitutions"] + measures["deletions"] + measures["hits"]
+             return incorrect / total
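
As a quick check of the formula against the doctest above: "this is the prediction" vs. "this is the reference" contributes one substitution, and "there is an other sample" vs. "there is another one" contributes two substitutions plus one insertion, so WER = (S + D + I) / N = (3 + 0 + 1) / 8 = 0.5.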
model-bin/metrics/wer/wer.py.lock ADDED
File without changes
model-bin/pretrained/base/config.json ADDED
@@ -0,0 +1,78 @@
+ {
+   "_name_or_path": "nguyenvulebinh/wav2vec2_vi",
+   "activation_dropout": 0.1,
+   "apply_spec_augment": true,
+   "architectures": [
+     "Wav2Vec2ForPreTraining"
+   ],
+   "attention_dropout": 0.1,
+   "bos_token_id": 1,
+   "codevector_dim": 256,
+   "contrastive_logits_temperature": 0.1,
+   "conv_bias": false,
+   "conv_dim": [
+     512,
+     512,
+     512,
+     512,
+     512,
+     512,
+     512
+   ],
+   "conv_kernel": [
+     10,
+     3,
+     3,
+     3,
+     3,
+     2,
+     2
+   ],
+   "conv_stride": [
+     5,
+     2,
+     2,
+     2,
+     2,
+     2,
+     2
+   ],
+   "ctc_loss_reduction": "sum",
+   "ctc_zero_infinity": false,
+   "diversity_loss_weight": 0.1,
+   "do_stable_layer_norm": false,
+   "eos_token_id": 2,
+   "feat_extract_activation": "gelu",
+   "feat_extract_dropout": 0.0,
+   "feat_extract_norm": "group",
+   "feat_proj_dropout": 0.1,
+   "feat_quantizer_dropout": 0.0,
+   "final_dropout": 0.1,
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout": 0.1,
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-05,
+   "layerdrop": 0.1,
+   "mask_feature_length": 10,
+   "mask_feature_prob": 0.0,
+   "mask_time_length": 10,
+   "mask_time_prob": 0.05,
+   "model_type": "wav2vec2",
+   "num_attention_heads": 12,
+   "num_codevector_groups": 2,
+   "num_codevectors_per_group": 320,
+   "num_conv_pos_embedding_groups": 16,
+   "num_conv_pos_embeddings": 128,
+   "num_feat_extract_layers": 7,
+   "num_hidden_layers": 12,
+   "num_negatives": 100,
+   "pad_token_id": 0,
+   "proj_codevector_dim": 256,
+   "torch_dtype": "float32",
+   "transformers_version": "4.9.1",
+   "vocab_size": 110
+ }
model-bin/pretrained/base/preprocessor_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "do_normalize": true,
+   "feature_size": 1,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "return_attention_mask": false,
+   "sampling_rate": 16000
+ }
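
These are the values that Wav2Vec2FeatureExtractor.from_pretrained('./model-bin/pretrained/base') in main.py will pick up from this file; a quick illustrative check:

from transformers import Wav2Vec2FeatureExtractor

fe = Wav2Vec2FeatureExtractor.from_pretrained("./model-bin/pretrained/base")
print(fe.sampling_rate, fe.do_normalize)  # 16000 True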
model-bin/pretrained/base/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8b36355988e4d1f94d070ef677ab4d304bce440af0c3dd7bd1c98e295e907f09
+ size 380261837
requirments.txt ADDED
@@ -0,0 +1,6 @@
+ soundfile
+ transformers==4.9.2
+ torch==1.9.0
+ datasets==1.11.0
+ jiwer
+ tensorboard