nguyenvulebinh
committed on
Commit 778e524
1 Parent(s): 1155872
add init code train
Files changed:
- .gitignore +258 -0
- data_handler.py +68 -0
- main.py +180 -0
- metric_utils.py +25 -0
- model-bin/metrics/wer/wer.py +105 -0
- model-bin/metrics/wer/wer.py.lock +0 -0
- model-bin/pretrained/base/config.json +78 -0
- model-bin/pretrained/base/preprocessor_config.json +8 -0
- model-bin/pretrained/base/pytorch_model.bin +3 -0
- requirments.txt +6 -0
.gitignore
ADDED
@@ -0,0 +1,258 @@

# Created by https://www.toptal.com/developers/gitignore/api/pycharm,python
# Edit at https://www.toptal.com/developers/gitignore?templates=pycharm,python

### PyCharm ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### PyCharm Patch ###
# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721

# *.iml
# modules.xml
# .idea/misc.xml
# *.ipr

# Sonarlint plugin
# https://plugins.jetbrains.com/plugin/7973-sonarlint
.idea/**/sonarlint/

# SonarQube Plugin
# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
.idea/**/sonarIssues.xml

# Markdown Navigator plugin
# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
.idea/**/markdown-navigator.xml
.idea/**/markdown-navigator-enh.xml
.idea/**/markdown-navigator/

# Cache file creation bug
# See https://youtrack.jetbrains.com/issue/JBR-2257
.idea/$CACHE_FILE$

# CodeStream plugin
# https://plugins.jetbrains.com/plugin/12206-codestream
.idea/codestream.xml

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# End of https://www.toptal.com/developers/gitignore/api/pycharm,python


data-bin/

.DS_Store

.idea/
data_handler.py
ADDED
@@ -0,0 +1,68 @@
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from transformers import Wav2Vec2Processor


@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 so padded positions are ignored by the CTC loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch
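
A minimal usage sketch of the collator (not part of this commit): it reuses the vocab and pretrained paths that main.py uses, and the two features below are dummy values for illustration only.

# Usage sketch (not part of this commit); dummy features, paths taken from main.py.
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
from data_handler import DataCollatorCTCWithPadding

tokenizer = Wav2Vec2CTCTokenizer("./model-bin/finetune/vocab.json",
                                 unk_token="<unk>", pad_token="<pad>", word_delimiter_token="|")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./model-bin/pretrained/base")
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
features = [
    {"input_values": [0.1] * 16000, "labels": [5, 6, 7]},  # ~1 s of dummy audio
    {"input_values": [0.0] * 8000, "labels": [5, 6]},       # shorter example gets padded
]
batch = collator(features)
# input_values is padded to the longest sequence; padded label positions become -100
print(batch["input_values"].shape, batch["labels"])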
main.py
ADDED
@@ -0,0 +1,180 @@
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
from datasets import load_from_disk
from data_handler import DataCollatorCTCWithPadding
from transformers import TrainingArguments
from transformers import Trainer, logging
from metric_utils import compute_metrics_fn
from transformers.trainer_utils import get_last_checkpoint
import json, random
import os, glob

logging.set_verbosity_info()


def load_pretrained_model(checkpoint_path=None):
    if checkpoint_path is None:
        pre_trained_path = './model-bin/pretrained/base'
        tokenizer = Wav2Vec2CTCTokenizer("./model-bin/finetune/vocab.json",
                                         unk_token="<unk>",
                                         pad_token="<pad>",
                                         word_delimiter_token="|")

        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pre_trained_path)
        processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

        model = Wav2Vec2ForCTC.from_pretrained(
            pre_trained_path,
            gradient_checkpointing=True,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
        )
        model.freeze_feature_extractor()
    else:
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(checkpoint_path)

        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint_path)
        processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

        model = Wav2Vec2ForCTC.from_pretrained(
            checkpoint_path,
            gradient_checkpointing=True,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
        )
        # model.freeze_feature_extractor()

    model_total_params = sum(p.numel() for p in model.parameters())
    model_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(model)
    print("model_total_params: {}\nmodel_total_params_trainable: {}".format(model_total_params,
                                                                             model_total_params_trainable))
    return model, processor


def prepare_dataset(batch, processor):
    # check that all files have the correct sampling rate
    assert (
        len(set(batch["sampling_rate"])) == 1
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values

    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch


def load_prepared_dataset(path, processor, cache_file_name):
    dataset = load_from_disk(path)
    processed_dataset = dataset.map(prepare_dataset,
                                    remove_columns=dataset.column_names,
                                    batch_size=8,
                                    num_proc=8,
                                    batched=True,
                                    fn_kwargs={"processor": processor},
                                    cache_file_name=cache_file_name)
    return processed_dataset


# def get_train_dataset():
#     for i in range()

if __name__ == "__main__":

    checkpoint_path = "./model-bin/finetune/base/"
    train_dataset_root_folder = './data-bin/train_dataset'
    test_dataset_root_folder = './data-bin/test_dataset'
    cache_processing_dataset_folder = './data-bin/cache/'
    if not os.path.exists(cache_processing_dataset_folder):
        os.makedirs(cache_processing_dataset_folder)
    num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
    num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
    num_epochs = 20

    training_args = TrainingArguments(
        output_dir=checkpoint_path,
        # fp16=True,
        group_by_length=True,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=1,
        num_train_epochs=1,  # each epoch covers one data shard
        logging_steps=1,
        learning_rate=1e-4,
        weight_decay=0.005,
        warmup_steps=5000,
        save_total_limit=2,
        ignore_data_skip=True,
        logging_dir=os.path.join(checkpoint_path, 'log'),
        metric_for_best_model='wer',
        save_strategy="epoch",
        evaluation_strategy="epoch",
        # save_steps=5,
        # eval_steps=5,
    )

    # PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
    last_checkpoint_path = None
    last_epoch_idx = 0
    if os.path.exists(checkpoint_path):
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)
        if last_checkpoint_path is not None:
            with open(os.path.join(last_checkpoint_path, "trainer_state.json"), 'r', encoding='utf-8') as file:
                trainer_state = json.load(file)
                last_epoch_idx = int(trainer_state['epoch'])

    w2v_ctc_model, w2v_ctc_processor = load_pretrained_model()
    data_collator = DataCollatorCTCWithPadding(processor=w2v_ctc_processor, padding=True)

    for epoch_idx in range(last_epoch_idx, num_epochs):
        # loop over training shards
        train_dataset_shard_idx = epoch_idx % num_train_shards
        # pick the test shard that corresponds to the current train shard
        test_dataset_shard_idx = round(train_dataset_shard_idx / (num_train_shards / num_test_shards))
        num_test_sub_shard = 1000  # split the test shard into subsets; default is 8
        idx_sub_shard = train_dataset_shard_idx % num_test_sub_shard  # loop over test shard subsets

        # load train shard
        train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
                                                           'shard_{}'.format(train_dataset_shard_idx)),
                                              w2v_ctc_processor,
                                              cache_file_name=os.path.join(cache_processing_dataset_folder,
                                                                           'cache-train-shard-{}.arrow'.format(
                                                                               train_dataset_shard_idx))
                                              ).shard(1000, 0)  # remove this shard split for a real training run
        # load test shard subset
        test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                                          'shard_{}'.format(test_dataset_shard_idx)),
                                             w2v_ctc_processor,
                                             cache_file_name=os.path.join(cache_processing_dataset_folder,
                                                                          'cache-test-shard-{}.arrow'.format(
                                                                              test_dataset_shard_idx))
                                             ).shard(num_test_sub_shard, idx_sub_shard)

        # Init trainer
        trainer = Trainer(
            model=w2v_ctc_model,
            data_collator=data_collator,
            args=training_args,
            compute_metrics=compute_metrics_fn(w2v_ctc_processor),
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            tokenizer=w2v_ctc_processor.feature_extractor
        )
        # manually bump num_train_epochs because each "epoch" loops over a single shard
        training_args.num_train_epochs = epoch_idx + 1

        logging.get_logger().info('Train shard idx: {}'.format(train_dataset_shard_idx))
        logging.get_logger().info('Valid shard idx: {} sub_shard: {}'.format(test_dataset_shard_idx, idx_sub_shard))

        if last_checkpoint_path is not None:
            # resume training from a checkpoint if one exists
            trainer.train(resume_from_checkpoint=True)
        else:
            # train from the pre-trained wav2vec2 checkpoint
            trainer.train()
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)

        # clear cache files to free disk space
        # test_dataset.cleanup_cache_files()
        # train_dataset.cleanup_cache_files()
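
main.py assumes dataset shards saved with datasets' save_to_disk under data-bin/train_dataset/shard_<idx> and data-bin/test_dataset/shard_<idx>, each containing "speech", "sampling_rate" and "target_text" columns (those are the columns prepare_dataset reads). A hypothetical sketch of building one such shard, not part of this commit:

# Hypothetical shard-building sketch (not part of this commit); assumes 16 kHz mono WAV files.
import soundfile as sf
from datasets import Dataset

def build_shard(wav_paths, transcripts, out_dir):
    records = {"speech": [], "sampling_rate": [], "target_text": []}
    for path, text in zip(wav_paths, transcripts):
        audio, rate = sf.read(path)              # waveform as float array, sample rate as int
        records["speech"].append(audio.tolist())
        records["sampling_rate"].append(rate)
        records["target_text"].append(text)
    Dataset.from_dict(records).save_to_disk(out_dir)

# build_shard(["sample.wav"], ["xin chao"], "./data-bin/train_dataset/shard_0")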
metric_utils.py
ADDED
@@ -0,0 +1,25 @@
import numpy as np
from datasets import load_metric

wer_metric = load_metric("./model-bin/metrics/wer")


# print(wer_metric)


def compute_metrics_fn(processor):
    def compute(pred):
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

        wer = wer_metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    return compute
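
compute_metrics_fn is handed to the Trainer in main.py; the WER itself comes from the local metric script under model-bin/metrics/wer. A quick sanity check of that metric (illustrative strings only, not part of this commit):

# WER sanity check (not part of this commit); one substitution out of three words -> 1/3.
from datasets import load_metric

wer_metric = load_metric("./model-bin/metrics/wer")
print(wer_metric.compute(predictions=["toi di hoc"], references=["toi di choi"]))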
model-bin/metrics/wer/wer.py
ADDED
@@ -0,0 +1,105 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Word Error Rate (WER) metric. """

from jiwer import compute_measures

import datasets

_CITATION = """\
@inproceedings{inproceedings,
    author = {Morris, Andrew and Maier, Viktoria and Green, Phil},
    year = {2004},
    month = {01},
    pages = {},
    title = {From WER and RIL to MER and WIL: improved evaluation measures for connected speech recognition.}
}
"""

_DESCRIPTION = """\
Word error rate (WER) is a common metric of the performance of an automatic speech recognition system.

The general difficulty of measuring performance lies in the fact that the recognized word sequence can have a different length from the reference word sequence (supposedly the correct one). The WER is derived from the Levenshtein distance, working at the word level instead of the phoneme level. The WER is a valuable tool for comparing different systems as well as for evaluating improvements within one system. This kind of measurement, however, provides no details on the nature of translation errors and further work is therefore required to identify the main source(s) of error and to focus any research effort.

This problem is solved by first aligning the recognized word sequence with the reference (spoken) word sequence using dynamic string alignment. Examination of this issue is seen through a theory called the power law that states the correlation between perplexity and word error rate.

Word error rate can then be computed as:

WER = (S + D + I) / N = (S + D + I) / (S + D + C)

where

S is the number of substitutions,
D is the number of deletions,
I is the number of insertions,
C is the number of correct words,
N is the number of words in the reference (N=S+D+C).

WER's output is always a number between 0 and 1. This value indicates the percentage of words that were incorrectly predicted. The lower the value, the better the
performance of the ASR system with a WER of 0 being a perfect score.
"""

_KWARGS_DESCRIPTION = """
Compute WER score of transcribed segments against references.

Args:
    references: List of references for each speech input.
    predictions: List of transcriptions to score.
    concatenate_texts (bool, default=False): Whether to concatenate all input texts or compute WER iteratively.

Returns:
    (float): the word error rate

Examples:

    >>> predictions = ["this is the prediction", "there is an other sample"]
    >>> references = ["this is the reference", "there is another one"]
    >>> wer = datasets.load_metric("wer")
    >>> wer_score = wer.compute(predictions=predictions, references=references)
    >>> print(wer_score)
    0.5
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class WER(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string", id="sequence"),
                    "references": datasets.Value("string", id="sequence"),
                }
            ),
            codebase_urls=["https://github.com/jitsi/jiwer/"],
            reference_urls=[
                "https://en.wikipedia.org/wiki/Word_error_rate",
            ],
        )

    def _compute(self, predictions=None, references=None, concatenate_texts=False):
        if concatenate_texts:
            return compute_measures(references, predictions)["wer"]
        else:
            incorrect = 0
            total = 0
            for prediction, reference in zip(predictions, references):
                measures = compute_measures(reference, prediction)
                incorrect += measures["substitutions"] + measures["deletions"] + measures["insertions"]
                total += measures["substitutions"] + measures["deletions"] + measures["hits"]
            return incorrect / total
model-bin/metrics/wer/wer.py.lock
ADDED
File without changes
model-bin/pretrained/base/config.json
ADDED
@@ -0,0 +1,78 @@
{
  "_name_or_path": "nguyenvulebinh/wav2vec2_vi",
  "activation_dropout": 0.1,
  "apply_spec_augment": true,
  "architectures": ["Wav2Vec2ForPreTraining"],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "codevector_dim": 256,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": false,
  "conv_dim": [512, 512, 512, 512, 512, 512, 512],
  "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
  "conv_stride": [5, 2, 2, 2, 2, 2, 2],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": false,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "group",
  "feat_proj_dropout": 0.1,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout": 0.1,
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.1,
  "mask_feature_length": 10,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_prob": 0.05,
  "model_type": "wav2vec2",
  "num_attention_heads": 12,
  "num_codevector_groups": 2,
  "num_codevectors_per_group": 320,
  "num_conv_pos_embedding_groups": 16,
  "num_conv_pos_embeddings": 128,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 12,
  "num_negatives": 100,
  "pad_token_id": 0,
  "proj_codevector_dim": 256,
  "torch_dtype": "float32",
  "transformers_version": "4.9.1",
  "vocab_size": 110
}
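
One way to read the conv_stride entries above: their product is the feature extractor's downsampling factor, so 16 kHz input audio becomes roughly 50 CTC frames per second. A small sketch of that arithmetic (not part of this commit):

# Downsampling arithmetic sketch (not part of this commit).
conv_stride = [5, 2, 2, 2, 2, 2, 2]
samples_per_frame = 1
for s in conv_stride:
    samples_per_frame *= s                      # 5*2*2*2*2*2*2 = 320 samples per frame
frames_per_second = 16000 / samples_per_frame   # 50 CTC frames per second at 16 kHz
print(samples_per_frame, frames_per_second)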
model-bin/pretrained/base/preprocessor_config.json
ADDED
@@ -0,0 +1,8 @@
{
  "do_normalize": true,
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": false,
  "sampling_rate": 16000
}
model-bin/pretrained/base/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b36355988e4d1f94d070ef677ab4d304bce440af0c3dd7bd1c98e295e907f09
size 380261837
requirments.txt
ADDED
@@ -0,0 +1,6 @@
soundfile
transformers==4.9.2
torch==1.9.0
datasets==1.11.0
jiwer
tensorboard