push to HF

Files changed:
- .gitattributes +4 -5
- .gitignore +541 -0
- README.md +119 -4
- app.py +468 -0
- local/UV.py +29 -0
- local/VCTK_preprocessing.py +62 -0
- local/data_preparation.py +78 -0
- local/decode.py +427 -0
- local/duration_calcutator.py +19 -0
- local/fine-tuning.py +346 -0
- local/get_ASR.py +232 -0
- local/get_ref_PPM.py +269 -0
- local/new_whisper_fine_tuning.py +481 -0
- local/new_whisper_fine_tuning_decode.py +509 -0
- local/post_processing.py +56 -0
- local/wer_plot_report.py +45 -0
- local/whisper_fine_tuning_large_with_negel.py +442 -0
- local/whisper_fine_tuning_michael_100.py +442 -0
- local/whisper_fine_tuning_negel.py +431 -0
- local/whisper_fine_tuning_negel_decode.py +473 -0
- packages.txt +2 -0
- requirements.txt +103 -0
- src/description.html +13 -0
- src/lightning_module.py +43 -0
- src/model.py +191 -0
.gitattributes
CHANGED
@@ -2,13 +2,11 @@
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
 *.npy filter=lfs diff=lfs merge=lfs -text
@@ -16,16 +14,14 @@
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +29,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+src/epoch=3-step=7459.ckpt filter=lfs diff=lfs merge=lfs -text
+src/wav2vec_small.pt filter=lfs diff=lfs merge=lfs -text
+data/p326_split filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,541 @@
# generated by: https://github.com/michaelliao/gitignore-online-generator

#################### Python.gitignore ####################

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

#################### Archives.gitignore ####################

# It's better to unpack these files and commit the raw source because
# git has its own built in compression methods.
*.7z
*.jar
*.rar
*.zip
*.gz
*.gzip
*.tgz
*.bzip
*.bzip2
*.bz2
*.xz
*.lzma
*.cab
*.xar

# Packing-only formats
*.iso
*.tar

# Package management formats
*.dmg
*.xpi
*.gem
*.egg
*.deb
*.rpm
*.msi
*.msm
*.msp
*.txz

#################### Backup.gitignore ####################

*.bak
*.gho
*.ori
*.orig
*.tmp

#################### Emacs.gitignore ####################

# -*- mode: gitignore; -*-
*~
\#*\#
/.emacs.desktop
/.emacs.desktop.lock
*.elc
auto-save-list
tramp
.\#*

# Org-mode
.org-id-locations
*_archive

# flymake-mode
*_flymake.*

# eshell files
/eshell/history
/eshell/lastdir

# elpa packages
/elpa/

# reftex files
*.rel

# AUCTeX auto folder
/auto/

# cask packages
.cask/
dist/

# Flycheck
flycheck_*.el

# server auth directory
/server/

# projectiles files
.projectile

# directory configuration
.dir-locals.el

# network security
/network-security.data


#################### JetBrains.gitignore ####################

# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# SonarLint plugin
.idea/sonarlint/

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

#################### Linux.gitignore ####################

*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

#################### NotepadPP.gitignore ####################

# Notepad++ backups #
*.bak

#################### PuTTY.gitignore ####################

# Private key
*.ppk

#################### SublimeText.gitignore ####################

# Cache files for Sublime Text
*.tmlanguage.cache
*.tmPreferences.cache
*.stTheme.cache

# Workspace files are user-specific
*.sublime-workspace

# Project files should be checked into the repository, unless a significant
# proportion of contributors will probably not be using Sublime Text
# *.sublime-project

# SFTP configuration file
sftp-config.json
sftp-config-alt*.json

# Package control specific files
Package Control.last-run
Package Control.ca-list
Package Control.ca-bundle
Package Control.system-ca-bundle
Package Control.cache/
Package Control.ca-certs/
Package Control.merged-ca-bundle
Package Control.user-ca-bundle
oscrypto-ca-bundle.crt
bh_unicode_properties.cache

# Sublime-github package stores a github token in this file
# https://packagecontrol.io/packages/sublime-github
GitHub.sublime-settings

#################### Vim.gitignore ####################

# Swap
[._]*.s[a-v][a-z]
!*.svg # comment out if you don't need vector files
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]

# Session
Session.vim
Sessionx.vim

# Temporary
.netrwhist
*~
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~

#################### VirtualEnv.gitignore ####################

# Virtualenv
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
.Python
[Bb]in
[Ii]nclude
[Ll]ib
[Ll]ib64
[Ll]ocal
[Ss]cripts
pyvenv.cfg
.venv
pip-selfcheck.json

#################### VisualStudioCode.gitignore ####################

.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

#################### Windows.gitignore ####################

# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db

# Dump file
*.stackdump

# Folder config file
[Dd]esktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp

# Windows shortcuts
*.lnk

#################### macOS.gitignore ####################

# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

#################### Custom.gitignore ####################

# add your custom gitignore here:
!.gitignore
!.gitsubmodules

# ignore data
data/
exp/
!src/lightning_module.py
*.wav
# ignore plots
*.png
# ignore csv
*.csv


!config/template.yaml
config
!local

## Currently
src/wav2vec_small.pt
*.ckpt
*.bak

fine_tuned/

#
user/

.vscode
README.md
CHANGED
@@ -1,12 +1,127 @@
+# Laronix Data Collection
+
+This repository contains information about the Laronix data collection process, which involves collecting parallel data from AVA users. The dataset consists of two main sessions: scripted data and conversational data.
+
+## Dataset
+
+The dataset is organized as follows:
+
+### 1. Scripted Data
+
+The scripted data session includes 200 sentences collected from 5 articles. The audio and text references for these sentences have already been uploaded, or will be uploaded, to the Laronix Recording system (ask [Kevin](mailto:kevin@laronix.com) for these files). The distribution of sentences from each article is as follows:
+
+- Arthur the Rat: 56 sentences
+- Cinder: 19 sentences
+- Rainbow: 26 sentences
+- Sentences: 59 sentences
+- VCTK: 40 sentences
+
+### 2. Conversational Data
+
+The conversational data session focuses on natural conversations and involves the following components:
+
+#### a. Q&A
+
+In this component, a set of 50 sentences consisting of questions and answers will be provided. During the recording, the partner asks the questions (Q) and the patient provides the answers (A). Both the questions and the answers are recorded.
+
+#### b. Freestyle
+
+The patients are free to talk about a given topic and are asked to respond with 5 to 10 sentences. The structure for this component can be referenced from the [IELTS speaking test](https://www.ieltsbuddy.com/IELTS-speaking-questions-with-answers.html).
+
+## Data Inclusion Criteria
+
++ No hearing loss or history of active cancer.
++ 6 weeks of practice with AVA.
+
+## Document for Laronix Recording System
+
+The Laronix recording system is designed for data collection from potential users of the AVA device, which replaces their vocal cords.
+
+### Input:
+
+- Audio signal
+- Reference ID
+- Reference text
+- Reference phonemes per minute (PPM)
+
+### Output:
+
+- wav_pause_plot: Wave signal plot with pauses detected by a VAD algorithm (SNR = 40 dB)
+- Predicted Mean Opinion Score: Score estimating data quality on the MOS scale (1-5) using an ML prediction model
+- Hypotheses: Text predicted by the Automatic Speech Recognition model (wav2vec 2.0 + CTC)
+- WER: Word Error Rate (lower is better)
+- Predicted Phonemes
+- PPM: Phonemes per minute
+- Message: Feedback from the system
+
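As a rough orientation for these outputs, the sketch below mirrors how `app.py` (added in this commit) derives WER and PPM. The wav path and the reference/hypothesis strings are placeholders, and `app.py` additionally trims pauses with `torchaudio.functional.vad` before computing PPM.

```
# Minimal sketch of the WER / PPM computation (placeholder inputs, not part of this repo).
import jiwer
import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

signal, sr = librosa.load("example.wav", sr=16000)  # hypothetical recording

# Word Error Rate between the reference text and an ASR hypothesis.
wer = jiwer.wer("the reference text", "the asr hypothesis")

# Phoneme sequence from the same CTC phoneme recognizer that app.py loads.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
inputs = processor(signal, sampling_rate=sr, return_tensors="pt")
with torch.no_grad():
    logits = model(inputs.input_values).logits
phonemes = processor.batch_decode(torch.argmax(logits, dim=-1))[0].split(" ")

# PPM = phonemes per minute; app.py divides by the VAD-trimmed duration,
# the plain signal duration is used here for simplicity.
ppm = len(phonemes) / (len(signal) / sr) * 60
print(wer, ppm)
```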
+## User Instruction
+
+Please follow the instructions provided at the top of the APP page.
+
+```
+- Laronix_AUTOMOS
+  - data
+    - Template
+      - ref_wav/
+        - 1.wav
+        - 2.wav
+        - ...
+      - ref_txt.txt
+      - ref.csv              # audio prosody feature reference <generated by script>
+  - exp
+    - Template
+      - Audio_to_evaluate    # raw wav data
+      - log.csv              # recording log
+      - output               # wav files <generated by script>
+  - model
+    - epoch=3-step=7459.ckpt # MOS estimation model
+    - wav2vec_small.pt       # WER model
+  - local
+    - get_ref_PPM.py         # script for generating data/<ref_dir>/ref.csv
+    - post_processing.py     # script for generating exp/<ref_dir>/output/*.wav
+```
+
 ---
-title: Laronix
-emoji:
-colorFrom:
+title: Laronix Automos
+emoji: 🏃
+colorFrom: blue
 colorTo: blue
 sdk: gradio
-sdk_version: 3.
+sdk_version: 3.2
 app_file: app.py
 pinned: false
+license: afl-3.0
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+# Laronix_AutoMOS
+
+## Usage:
+### Step 1: Prepare data and text
+`<todo>`
+### Step 2: Preprocessing
+```
+## Generating *.csv, Voice/Unvoice plot (optional) and config (optional)
+python local/get_ref_PPM.py --ref_txt <ref_text> \
+    --ref_wavs <ref_wavs> \
+    --output_dir <output_dir> \
+    --to_config <True/False> \
+    --UV_flag <True/False> \
+    --UV_thre <UV_thre>
+```
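For example, using the `data/Template` layout shown in the directory tree above, a concrete invocation could look like the following; the paths and the `--UV_thre` value are illustrative placeholders (the script itself is not part of this commit).

```
python local/get_ref_PPM.py --ref_txt data/Template/ref_txt.txt \
    --ref_wavs data/Template/ref_wav \
    --output_dir data/Template \
    --to_config True \
    --UV_flag True \
    --UV_thre 40
```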
+### Step 3: Launch recording session
+
+```
+## Start app.py
+python app.py <config.yaml>
+```
++ **Find the logging below and click the URL to start**
+```
+Launch examples
+Running on local URL: http://127.0.0.1:7860/
+...
+(Logs...)
+...
+Running on public URL: https://87abe771e93229da.gradio.app
+```
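`app.py` reads its settings from the YAML file passed on the command line. Below is a minimal sketch of such a config with the keys `app.py` actually looks up (`ref_txt`, `ref_feature`, `ref_wavs`, `thre`, `exp_id`); the paths and threshold values are placeholder assumptions, not values shipped with this commit.

```
# Hypothetical config.yaml; keys follow what app.py reads, values are examples only.
ref_txt: data/Template/ref_txt.txt   # "<id> <reference text>" per line
ref_feature: data/Template/ref.csv   # generated by local/get_ref_PPM.py (last column = PPM)
ref_wavs: data/Template/ref_wav      # directory of reference *.wav files
thre:
  maxppm: 10      # ppm above reference + maxppm -> "Please speak slower"
  minppm: 10      # ppm below reference - minppm -> "Please speak faster"
  AUTOMOS: 3.0    # minimum acceptable predicted MOS
  WER: 0.1        # maximum acceptable word error rate
exp_id: Template  # flagged recordings go to ./exp/<exp_id>; defaults to the config file stem
```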
app.py
ADDED
@@ -0,0 +1,468 @@
"""
TODO:
    + [x] Load Configuration
    + [ ] Checking
    + [ ] Better saving directory
"""
import numpy as np
from pathlib import Path
import jiwer
import pdb
import torch.nn as nn
import torch
import torchaudio
import gradio as gr
from logging import PlaceHolder
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import yaml
from transformers import pipeline
import librosa
import librosa.display
import matplotlib.pyplot as plt


# local import
import sys

sys.path.append("src")
import lightning_module

# Load automos
config_yaml = sys.argv[1]
with open(config_yaml, "r") as f:
    # pdb.set_trace()
    try:
        config = yaml.safe_load(f)
    except yaml.YAMLError:
        print("Config file Loading Error")
        exit()

# Auto load examples

with open(config["ref_txt"], "r") as f:
    refs = f.readlines()
    refs_ids = [x.split()[0] for x in refs]
    refs_txt = [" ".join(x.split()[1:]) for x in refs]

ref_feature = np.loadtxt(config["ref_feature"], delimiter=",", dtype="str")
ref_wavs = [str(x) for x in sorted(Path(config["ref_wavs"]).glob("**/*.wav"))]

dummy_wavs = [None for x in np.arange(len(ref_wavs))]

refs_ppm = np.array(ref_feature[:, -1][1:], dtype="str")

reference_id = gr.Textbox(value="ID", placeholder="Utter ID", label="Reference_ID")

reference_textbox = gr.Textbox(
    value="Input reference here",
    placeholder="Input reference here",
    label="Reference",
)
reference_PPM = gr.Textbox(placeholder="Pneumatic Voice's PPM", label="Ref PPM")

# Set up interface
# remove dummy wavs, use the same ref_wavs for eval wavs
print("Preparing Examples")
examples = [
    [w, w_, i, x, y] for w, w_, i, x, y in zip(ref_wavs, ref_wavs, refs_ids, refs_txt, refs_ppm)
]

p = pipeline(
    "automatic-speech-recognition",
    model="KevinGeng/whipser_medium_en_PAL300_step25",
    device=0,
)

# WER part
transformation = jiwer.Compose(
    [
        jiwer.RemovePunctuation(),
        jiwer.ToLowerCase(),
        jiwer.RemoveWhiteSpace(replace_by_space=True),
        jiwer.RemoveMultipleSpaces(),
        jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
    ]
)

# WPM part
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")


class ChangeSampleRate(nn.Module):
    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.tensor) -> torch.tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        indices = torch.arange(new_length) * (self.input_rate / self.output_rate)
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(0) + (
            round_up * indices.fmod(1.0).unsqueeze(0)
        )
        return output


# MOS model
model = lightning_module.BaselineLightningModule.load_from_checkpoint(
    "src/epoch=3-step=7459.ckpt"
).eval()

# Get Speech Interval

def get_speech_interval(signal, db):
    audio_interv = librosa.effects.split(signal, top_db=db)
    pause_end = [x[0] for x in audio_interv[1:]]
    pause_start = [x[1] for x in audio_interv[0:-1]]
    pause_interv = [[x, y] for x, y in zip(pause_start, pause_end)]
    return audio_interv, pause_interv

# plot UV


def plot_UV(signal, audio_interv, sr):
    fig, ax = plt.subplots(nrows=2, sharex=True)
    librosa.display.waveshow(signal, sr=sr, ax=ax[0])
    uv_flag = np.zeros(len(signal))
    for i in audio_interv:
        uv_flag[i[0] : i[1]] = 1

    ax[1].plot(np.arange(len(signal)) / sr, uv_flag, "r")
    ax[1].set_ylim([-0.1, 1.1])
    return fig

def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
    if audio_path == None:
        audio_path = _
        print("using ref audio as eval audio since it's empty")

    wav, sr = torchaudio.load(audio_path)
    if wav.shape[0] != 1:
        wav = wav[0, :]
        print(wav.shape)

    osr = 16000
    batch = wav.unsqueeze(0).repeat(10, 1, 1)
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)

    # ASR
    trans = jiwer.ToLowerCase()(p(audio_path)["text"])

    # WER
    wer = jiwer.wer(
        ref,
        trans,
        truth_transform=transformation,
        hypothesis_transform=transformation,
    )
    # MOS
    batch = {
        "wav": out_wavs,
        "domains": torch.tensor([0]),
        "judge_id": torch.tensor([288]),
    }
    with torch.no_grad():
        output = model(batch)
    predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3

    # Phonemes per minute (PPM)
    with torch.no_grad():
        logits = phoneme_model(out_wavs).logits
    phone_predicted_ids = torch.argmax(logits, dim=-1)
    phone_transcription = processor.batch_decode(phone_predicted_ids)
    lst_phonemes = phone_transcription[0].split(" ")

    # VAD for pause detection
    wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
    # pdb.set_trace()
    a_h, p_h = get_speech_interval(wav_vad.numpy(), db=40)
    # print(a_h)
    # print(len(a_h))
    fig_h = plot_UV(wav_vad.numpy().squeeze(), a_h, sr=sr)
    ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60

    error_msg = "!!! ERROR MESSAGE !!!\n"
    if audio_path == _ or audio_path == None:
        error_msg += "ERROR: Fail recording, Please start from the beginning again."
        return (
            fig_h,
            predic_mos,
            trans,
            wer,
            phone_transcription,
            ppm,
            error_msg,
        )
    if ppm >= float(pre_ppm) + float(config["thre"]["maxppm"]):
        error_msg += "ERROR: Please speak slower.\n"
    elif ppm <= float(pre_ppm) - float(config["thre"]["minppm"]):
        error_msg += "ERROR: Please speak faster.\n"
    elif predic_mos <= float(config["thre"]["AUTOMOS"]):
        error_msg += "ERROR: Naturalness is too low, Please try again.\n"
    elif wer >= float(config["thre"]["WER"]):
        error_msg += "ERROR: Intelligibility is too low, Please try again\n"
    else:
        error_msg = (
            "GOOD JOB! Please 【Save the Recording】.\nYou can start recording the next sample."
        )

    return (
        fig_h,
        predic_mos,
        trans,
        wer,
        phone_transcription,
        ppm,
        error_msg,
    )

with open("src/description.html", "r", encoding="utf-8") as f:
    description = f.read()
# description

refs_ppm = np.array(ref_feature[:, -1][1:], dtype="str")

reference_id = gr.Textbox(value="ID", placeholder="Utter ID", label="Reference_ID", visible=False)
reference_textbox = gr.Textbox(
    value="Input reference here",
    placeholder="Input reference here",
    label="Reference",
)
reference_PPM = gr.Textbox(placeholder="Pneumatic Voice's PPM", label="Ref PPM", visible=False)

# Flagging setup

# Interface
# Participant Information
def record_part_info(name, gender, first_lng):
    message = "Participant information is successfully collected."
    id_str = "%s_%s_%s" % (name, gender[0], first_lng[0])

    if name == None:
        message = "ERROR: Name Information incomplete!"
        id_str = "ERROR"

    if gender == None:
        message = "ERROR: Please select gender"
        id_str = "ERROR"

    if len(gender) > 1:
        message = "ERROR: Please select one gender only"
        id_str = "ERROR"

    if first_lng == None:
        message = "ERROR: Please select your english proficiency"
        id_str = "ERROR"

    if len(first_lng) > 1:
        message = "ERROR: Please select one english proficiency only"
        id_str = "ERROR"

    return message, id_str


# information page, not used for now
name = gr.Textbox(placeholder="Name", label="Name")
gender = gr.CheckboxGroup(["Male", "Female"], label="gender")
first_lng = gr.CheckboxGroup(
    [
        "B1 Intermediate",
        "B2: Upper Intermediate",
        "C1: Advanced",
        "C2: Proficient",
    ],
    label="English Proficiency (CEFR)",
)

msg = gr.Textbox(placeholder="Evaluation for valid participant", label="message")
id_str = gr.Textbox(placeholder="participant id", label="participant_id")

info = gr.Interface(
    fn=record_part_info,
    inputs=[name, gender, first_lng],
    outputs=[msg, id_str],
    title="Participant Information Page",
    allow_flagging="never",
    css="body {background-color: blue}",
)
# Experiment
if config["exp_id"] == None:
    config["exp_id"] = Path(config_yaml).stem

## This is the theme for the interface
css = """
.ref_text textarea {font-size: 40px !important}
.message textarea {font-size: 40px !important}
"""

my_theme = gr.themes.Default().set(
    button_primary_background_fill="#75DA99",
    button_primary_background_fill_dark="#DEF2D7",
    button_primary_text_color="black",
    button_secondary_text_color="black",
)

# Callback for saving the recording
callback = gr.CSVLogger()

with gr.Blocks(css=css, theme=my_theme) as demo:
    with gr.Column():
        with gr.Row():
            ref_audio = gr.Audio(
                source="microphone",
                type="filepath",
                label="Reference_Audio",
                container=True,
                interactive=False,
                visible=False,
            )
        with gr.Row():
            eval_audio = gr.Audio(
                source="microphone",
                type="filepath",
                container=True,
                label="Audio_to_Evaluate",
            )
            b_redo = gr.ClearButton(
                value="Redo", variant="stop", components=[eval_audio], size="sm"
            )
        reference_textbox = gr.Textbox(
            value="Input reference here",
            placeholder="Input reference here",
            label="Reference",
            interactive=True,
            elem_classes="ref_text",
        )
        with gr.Accordion("Input for Development", open=False):
            reference_id = gr.Textbox(
                value="ID",
                placeholder="Utter ID",
                label="Reference_ID",
                visible=True,
            )
            reference_PPM = gr.Textbox(
                placeholder="Pneumatic Voice's PPM",
                label="Ref PPM",
                visible=True,
            )
        with gr.Row():
            b = gr.Button(value="1.Submit", variant="primary", elem_classes="submit")

        # TODO
        # b_more = gr.Button(value="Show More", elem_classes="verbose")
        with gr.Row():
            inputs = [
                ref_audio,
                eval_audio,
                reference_id,
                reference_textbox,
                reference_PPM,
            ]
            e = gr.Examples(examples, inputs, examples_per_page=5)

    with gr.Column():
        with gr.Row():
            ## output block
            msg = gr.Textbox(
                placeholder="Recording Feedback",
                label="Message",
                interactive=False,
                elem_classes="message",
            )
        with gr.Accordion("Output for Development", open=False):
            wav_plot = gr.Plot(PlaceHolder="Wav/Pause Plot", label="wav_pause_plot", visible=True)

            predict_mos = gr.Textbox(
                placeholder="Predicted MOS",
                label="Predicted MOS",
                visible=True,
            )

            hyp = gr.Textbox(placeholder="Hypothesis", label="Hypothesis", visible=True)

            wer = gr.Textbox(placeholder="Word Error Rate", label="WER", visible=True)

            predict_pho = gr.Textbox(
                placeholder="Predicted Phonemes",
                label="Predicted Phonemes",
                visible=True,
            )

            ppm = gr.Textbox(
                placeholder="Phonemes per minute",
                label="PPM",
                visible=True,
            )
        outputs = [
            wav_plot,
            predict_mos,
            hyp,
            wer,
            predict_pho,
            ppm,
            msg,
        ]

        # b = gr.Button("Submit")
        b.click(fn=calc_mos, inputs=inputs, outputs=outputs, api_name="Submit")

        # Logger
        callback.setup(
            components=[
                eval_audio,
                reference_id,
                reference_textbox,
                reference_PPM,
                predict_mos,
                hyp,
                wer,
                ppm,
                msg],
            flagging_dir="./exp/%s" % config["exp_id"],
        )

        with gr.Row():
            b2 = gr.Button("2. Save the Recording", variant="primary", elem_id="save")
            js_confirmed_saving = "(x) => confirm('Recording Saved!')"
            # eval_audio,
            b2.click(
                lambda *args: callback.flag(args),
                inputs=[
                    eval_audio,
                    reference_id,
                    reference_textbox,
                    reference_PPM,
                    predict_mos,
                    hyp,
                    wer,
                    ppm,
                    msg,
                ],
                outputs=None,
                preprocess=False,
                api_name="flagging",
            )
        with gr.Row():
            b3 = gr.ClearButton(
                [
                    ref_audio,
                    eval_audio,
                    reference_id,
                    reference_textbox,
                    reference_PPM,
                    predict_mos,
                    hyp,
                    wer,
                    ppm,
                    msg,
                ],
                value="3.Clear All",
                elem_id="clear",
            )

demo.launch(share=True)
local/UV.py
ADDED
@@ -0,0 +1,29 @@
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt

# Plot_UV


def plot_UV(signal, audio_interv, sr):
    fig, ax = plt.subplots(nrows=2, sharex=True)
    librosa.display.waveshow(signal, sr=sr, ax=ax[0])
    ax[0].set_title("Signal")
    ax[1].set_title("U/V")
    uv_flag = np.zeros(len(signal))
    for i in audio_interv:
        uv_flag[i[0]: i[1]] = 1

    ax[1].plot(np.arange(len(signal))/sr, uv_flag, "r")
    ax[1].set_ylim([-0.1, 1.1])
    return fig

# Get Speech Interval


def get_speech_interval(signal, db):
    audio_interv = librosa.effects.split(signal, top_db=db)
    pause_end = [x[0] for x in audio_interv[1:]]
    pause_start = [x[1] for x in audio_interv[0: -1]]
    pause_interv = [[x, y] for x, y in zip(pause_start, pause_end)]
    return audio_interv, pause_interv
local/VCTK_preprocessing.py
ADDED
@@ -0,0 +1,62 @@
# Kevin @ Laronix Dec. 2022
# Data processing at Laronix
import csv
import soundfile as sf
import pandas as pd
from pathlib import Path
import librosa
import sys
import numpy as np
import pdb
from rich.progress import track

wavdir = sys.argv[1]
txtdir = sys.argv[2]
thre_len = int(sys.argv[3])
origin_sr = int(sys.argv[4])
target_sr = int(sys.argv[5])

wavs = sorted(Path(wavdir).glob("**/*.wav"))
txts = sorted(Path(txtdir).glob("**/*.txt"))
target_dir = "./data/%s_%d_%d_len%d" % (
    Path(wavdir).stem,
    origin_sr,
    target_sr,
    thre_len,
)

Path.mkdir(Path(target_dir), exist_ok=True)
# pdb.set_trace()
tables = []
for x, y in track(
    zip(wavs, txts), description="Processing...", total=len(wavs)
):
    label = 1
    with open(y, "r") as f:
        txt = f.readline()
    if len(txt.split(" ")) <= thre_len:
        label = 1
    record = [x, Path(x).stem, txt, len(txt.split(" ")), label]
    tables.append(record)
    # Select length <= 10 words sentences for training
    if len(txt.split(" ")) <= thre_len:
        wav, sr = librosa.load(x, sr=origin_sr)
        wav_ = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
        sf.write(
            Path(target_dir) / Path((x).stem + ".wav"),
            data=wav_,
            samplerate=target_sr,
        )

D = pd.DataFrame(
    tables, columns=["wav_path", "id", "text", "len", "length_label"]
)
D.to_csv(target_dir + ".datalog", sep=",")
print("Check data log at %s" % (target_dir + ".datalog"))

D.get(["id", "text"]).to_csv(
    target_dir + ".txt", sep="\t", header=False, index=False, quoting=3
)

print("Generate id_text at %s" % (target_dir + ".txt"))
print("Finish")
local/data_preparation.py
ADDED
@@ -0,0 +1,78 @@
import os
import pdb
import shutil
import pandas as pd
from datasets import Dataset, load_dataset

audio_dir = "./data/Patient_sil_trim_16k_normed_5_snr_40/"
# split_files = {"train": "data/Patient_sil_trim_16k_normed_5_snr_40/train.csv",
#                "test": "data/Patient_sil_trim_16k_normed_5_snr_40/test.csv",
#                "dev": "data/Patient_sil_trim_16k_normed_5_snr_40/dev.csv"}
src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
pdb.set_trace()

def train_dev_test_split(
    dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1, metadata_output=False, root_dir=None
):
    """
    input: dataset
           dev_rate,
           test_rate
           seed
    -------
    Output:
    dataset_dict{"train", "dev", "test"}
    """
    train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
    test = train_dev_test["test"]
    train_dev = train_dev_test["train"]

    if len(train_dev) <= int(len(dataset) * dev_rate):
        train = Dataset.from_dict({"audio": [], "transcription": []})
        dev = train_dev
    else:
        train_dev = train_dev.train_test_split(
            test_size=int(len(dataset) * dev_rate), seed=seed
        )
        train = train_dev["train"]
        dev = train_dev["test"]

    train_size = len(train)
    dev_size = len(dev)
    test_size = len(test)

    print(f"Train Size: {len(train)}")
    print(f"Dev Size: {len(dev)}")
    print(f"Test Size: {len(test)}")
    import pdb
    if metadata_output:
        pdb.set_trace()
        train_df = pd.DataFrame(train)
        dev_df = pd.DataFrame(dev)
        test_df = pd.DataFrame(test)

        if not os.path.exists(root_dir):
            raise FileNotFoundError(root_dir)

        # Create directories for train, dev, and test data
        import pdb
        if not os.path.exists(f'{root_dir}/train'):
            os.makedirs(f'{root_dir}/train')
        if not os.path.exists(f'{root_dir}/dev'):
            os.makedirs(f'{root_dir}/dev')
        if not os.path.exists(f'{root_dir}/test'):
            os.makedirs(f'{root_dir}/test')

        pdb.set_trace()
        train_df.to_csv(f'{root_dir}/train/metadata.csv', index=False)

        dev_df.to_csv(f'{root_dir}/dev/metadata.csv', index=False)

        test_df.to_csv(f'{root_dir}/test/metadata.csv', index=False)

    return train, dev, test

train, dev, test = train_dev_test_split(src_dataset, dev_rate=0.1, test_rate=0.1, seed=1, metadata_output=True, root_dir=audio_dir)

pdb.set_trace()
local/decode.py
ADDED
@@ -0,0 +1,427 @@
fine_tuning_dir = "fine_tuned/SSD/model/Negel_79_AVA_script_conv_train_conv_dev/checkpoint-50"

from typing import Any, Dict, List, Union
from dataclasses import dataclass
from transformers import Seq2SeqTrainer
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperFeatureExtractor, Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperModel
import evaluate
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from random import sample
from sys import flags
import gradio as gr
import torchaudio
import torch.nn as nn
import jiwer
import numpy as np
from rich import print as rprint
from rich.progress import track
from transformers import pipeline
import argparse
import yaml
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
from datasets import load_dataset, concatenate_datasets
from datasets import Dataset, Audio
import pdb
import string
import librosa
# local import
import sys

sys.path.append("src")
import lightning_module

torch.cuda.set_device("cuda:0")

audio_dir = "./data/Patient_sil_trim_16k_normed_5_snr_40"
healthy_dir = "./data/Healthy"
Fary_PAL_30 = "./data/Fary_PAL_p326_20230110_30"
John_p326 = "./data/John_p326/output"
John_video = "./data/20230103_video"
negel_79 = "./data/4_negel_79"

patient_T = "data/Patient_T/Patient_T"
patient_L = "data/Patient_L/Patient_L"
# Get Transcription, WER and PPM
"""
TODO:
    [DONE]: Automatic generating Config
"""


sys.path.append("./src")


wer = evaluate.load("wer")

# root_path = Path(__file__).parents[1]


class ChangeSampleRate(nn.Module):
    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.tensor) -> torch.tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        indices = torch.arange(new_length) * (
            self.input_rate / self.output_rate
        )
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
            0
        ) + round_up * indices.fmod(1.0).unsqueeze(0)
        return output

# resample and clean text data


def dataclean(example):
    # pdb.set_trace()
    if example['audio']['sampling_rate'] != 16000:
        resampled_audio = librosa.resample(y=example['audio']['array'],
                                           orig_sr=example['audio']['sampling_rate'],
                                           target_sr=16000)
        # torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
        # resampled_audio = resampler(example['audio']['array'])

        return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
                "transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
    else:
        return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}

processor = AutoFeatureExtractor.from_pretrained(
    "facebook/wav2vec2-base-960h"
)

def prepare_dataset(batch):
    audio = batch["audio"]
    batch = processor(
        audio["array"], sampling_rate=audio["sampling_rate"], text=batch['transcription'])
    batch["input_length"] = len(batch["input_values"][0])
    return batch


negel_79_dataset = load_dataset("audiofolder", data_dir=negel_79, split="train")
negel_79_dataset = negel_79_dataset.map(dataclean)

def train_dev_test_split(dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1):
    """
    input: dataset
           dev_rate,
           test_rate
           seed
    -------
    Output:
    dataset_dict{"train", "dev", "test"}
    """
    train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
    test = train_dev_test["test"]
    train_dev = train_dev_test['train']

    # pdb.set_trace()
    if len(train_dev) <= int(len(dataset)*dev_rate):
        train = Dataset.from_dict({"audio": [], "transcription": []})
        dev = train_dev
    else:
        train_dev = train_dev.train_test_split(test_size=int(len(dataset)*dev_rate), seed=seed)
        train = train_dev['train']
        dev = train_dev['test']
    return train, dev, test

# pdb.set_trace()
# P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1)
# P1tony_train_ = concatenate_datasets([P1tony_train,P1tony_scripted])
# pdb.set_trace()

Negel_79_train, Negel_79_dev, Negel_79_test = train_dev_test_split(negel_79_dataset, dev_rate=0.1, test_rate=0.1, seed=1)

# src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
# src_dataset = src_dataset.map(dataclean)

# healthy_test_dataset = load_dataset(
#     "audiofolder", data_dir=healthy_dir, split='train')
# healthy_test_dataset = healthy_test_dataset.map(dataclean)

# Fary_PAL_test_dataset = load_dataset(
#     "audiofolder", data_dir=Fary_PAL_30, split='train')
# Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)

# John_p326_test_dataset = load_dataset(
#     "audiofolder", data_dir=John_p326, split='train')
# John_p326_test_dataset = John_p326_test_dataset.map(dataclean)

# John_video_test_dataset = load_dataset(
#     "audiofolder", data_dir=John_video, split='train')
# John_video_test_dataset = John_video_test_dataset.map(dataclean)

# patient_T_test_dataset = load_dataset("audiofolder", data_dir=patient_T, split='train')
# patient_T_test_dataset = patient_T_test_dataset.map(dataclean)

# patient_L_test_dataset = load_dataset("audiofolder", data_dir=patient_L, split='train')
# patient_L_test_dataset = patient_L_test_dataset.map(dataclean)
# pdb.set_trace()

# train_dev / test
# ds = src_dataset.train_test_split(test_size=0.1, seed=1)

# dataset_libri = load_dataset(
#     "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# train_dev = ds['train']
# # train / dev
# train_dev = train_dev.train_test_split(
#     test_size=int(len(src_dataset)*0.1), seed=1)
# # train/dev/test
# train = train_dev['train']
# test = ds['test']
# dev = train_dev['test']

# # pdb.set_trace()
# encoded_train = train.map(prepare_dataset, num_proc=4)
# encoded_dev = dev.map(prepare_dataset, num_proc=4)
# encoded_test = test.map(prepare_dataset, num_proc=4)

# encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
# encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
# encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
|
193 |
+
# encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
|
194 |
+
# pdb.set_trace()
|
195 |
+
|
196 |
+
WER = evaluate.load("wer")
|
197 |
+
|
198 |
+
# Whisper decoding
|
199 |
+
|
200 |
+
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
|
201 |
+
model = WhisperForConditionalGeneration.from_pretrained(
|
202 |
+
"openai/whisper-medium").to("cuda:0")
|
203 |
+
tokenizer = WhisperTokenizer.from_pretrained(
|
204 |
+
"openai/whisper-medium", language="English", task="transcribe")
|
205 |
+
|
206 |
+
# Need to push tokenizer to hugginface/model to activate online API
|
207 |
+
|
208 |
+
# tokenizer.push_to_hub("KevinGeng/whipser_medium_en_PAL300_step25")
|
209 |
+
# import pdb
|
210 |
+
# pdb.set_trace()
|
211 |
+
|
212 |
+
feature_extractor = WhisperFeatureExtractor.from_pretrained(
|
213 |
+
"openai/whisper-medium")
|
214 |
+
|
215 |
+
|
216 |
+
def whisper_prepare_dataset(batch):
|
217 |
+
# load and resample audio data from 48 to 16kHz
|
218 |
+
audio = batch["audio"]
|
219 |
+
|
220 |
+
# compute log-Mel input features from input audio array
|
221 |
+
batch["input_features"] = feature_extractor(
|
222 |
+
audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
|
223 |
+
|
224 |
+
# encode target text to label ids
|
225 |
+
batch["labels"] = tokenizer(batch["transcription"]).input_ids
|
226 |
+
return batch
|
227 |
+
|
228 |
+
|
229 |
+
torch.cuda.empty_cache()
|
230 |
+
|
231 |
+
training_args = Seq2SeqTrainingArguments(
|
232 |
+
# change to a repo name of your choice
|
233 |
+
output_dir="./whisper-medium-PAL128-25step",
|
234 |
+
per_device_train_batch_size=8,
|
235 |
+
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
|
236 |
+
learning_rate=1e-5,
|
237 |
+
warmup_steps=100,
|
238 |
+
max_steps=1000,
|
239 |
+
gradient_checkpointing=True,
|
240 |
+
fp16=True,
|
241 |
+
evaluation_strategy="steps",
|
242 |
+
per_device_eval_batch_size=8,
|
243 |
+
predict_with_generate=True,
|
244 |
+
generation_max_length=512,
|
245 |
+
save_steps=100,
|
246 |
+
eval_steps=25,
|
247 |
+
logging_steps=100,
|
248 |
+
report_to=["tensorboard"],
|
249 |
+
load_best_model_at_end=True,
|
250 |
+
metric_for_best_model="wer",
|
251 |
+
greater_is_better=False,
|
252 |
+
push_to_hub=True,
|
253 |
+
)
|
254 |
+
|
255 |
+
|
256 |
+
def my_map_to_pred(batch):
|
257 |
+
# pdb.set_trace()
|
258 |
+
audio = batch["audio"]
|
259 |
+
input_features = processor(
|
260 |
+
audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
|
261 |
+
# batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
|
262 |
+
batch["reference"] = processor.tokenizer._normalize(batch['transcription'])
|
263 |
+
|
264 |
+
with torch.no_grad():
|
265 |
+
# predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
|
266 |
+
predicted_ids = model.generate(input_features.to("cuda"))[0]
|
267 |
+
transcription = model.decode(predicted_ids)
|
268 |
+
batch["prediction"] = model.tokenizer._normalize(transcription)
|
269 |
+
return batch
|
270 |
+
|
271 |
+
|
272 |
+
@dataclass
|
273 |
+
class DataCollatorSpeechSeq2SeqWithPadding:
|
274 |
+
processor: Any
|
275 |
+
|
276 |
+
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
277 |
+
# split inputs and labels since they have to be of different lengths and need different padding methods
|
278 |
+
# first treat the audio inputs by simply returning torch tensors
|
279 |
+
input_features = [{"input_features": feature["input_features"]}
|
280 |
+
for feature in features]
|
281 |
+
batch = self.processor.feature_extractor.pad(
|
282 |
+
input_features, return_tensors="pt")
|
283 |
+
|
284 |
+
# get the tokenized label sequences
|
285 |
+
label_features = [{"input_ids": feature["labels"]}
|
286 |
+
for feature in features]
|
287 |
+
# pad the labels to max length
|
288 |
+
labels_batch = self.processor.tokenizer.pad(
|
289 |
+
label_features, return_tensors="pt")
|
290 |
+
|
291 |
+
# replace padding with -100 to ignore loss correctly
|
292 |
+
labels = labels_batch["input_ids"].masked_fill(
|
293 |
+
labels_batch.attention_mask.ne(1), -100)
|
294 |
+
|
295 |
+
# if bos token is appended in previous tokenization step,
|
296 |
+
# cut bos token here as it's append later anyways
|
297 |
+
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
|
298 |
+
labels = labels[:, 1:]
|
299 |
+
|
300 |
+
batch["labels"] = labels
|
301 |
+
|
302 |
+
return batch
|
303 |
+
|
304 |
+
|
305 |
+
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
|
306 |
+
|
307 |
+
|
308 |
+
def compute_metrics(pred):
|
309 |
+
pdb.set_trace()
|
310 |
+
pred_ids = pred.predictions
|
311 |
+
label_ids = pred.label_ids
|
312 |
+
|
313 |
+
# replace -100 with the pad_token_id
|
314 |
+
label_ids[label_ids == -100] = tokenizer.pad_token_id
|
315 |
+
|
316 |
+
# we do not want to group tokens when computing the metrics
|
317 |
+
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
318 |
+
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
|
319 |
+
|
320 |
+
wer = 100 * WER.compute(predictions=pred_str, references=label_str)
|
321 |
+
|
322 |
+
return {"wer": wer}
|
323 |
+
|
324 |
+
encode_negel_79_train = Negel_79_train.map(whisper_prepare_dataset, num_proc=4)
|
325 |
+
encode_negel_79_dev = Negel_79_dev.map(whisper_prepare_dataset, num_proc=4)
|
326 |
+
encode_negel_79_test = Negel_79_test.map(whisper_prepare_dataset, num_proc=4)
|
327 |
+
pdb.set_trace()
|
328 |
+
torch.cuda.empty_cache()
|
329 |
+
|
330 |
+
torch.cuda.empty_cache()
|
331 |
+
|
332 |
+
fine_tuned_model = WhisperForConditionalGeneration.from_pretrained(
|
333 |
+
fine_tuning_dir
|
334 |
+
).to("cuda")
|
335 |
+
# "fine_tuned/SSD/model/whipser_medium_TEP_patient_T"
|
336 |
+
# "./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400"
|
337 |
+
#"./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-200"
|
338 |
+
|
339 |
+
|
340 |
+
def fine_tuned_map_to_pred(batch):
|
341 |
+
# pdb.set_trace()
|
342 |
+
audio = batch["audio"]
|
343 |
+
input_features = processor(
|
344 |
+
audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
|
345 |
+
# batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
|
346 |
+
batch["reference"] = processor.tokenizer._normalize(batch['transcription'])
|
347 |
+
|
348 |
+
with torch.no_grad():
|
349 |
+
# predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
|
350 |
+
predicted_ids = fine_tuned_model.generate(input_features.to("cuda"))[0]
|
351 |
+
transcription = tokenizer.decode(predicted_ids)
|
352 |
+
batch["prediction"] = tokenizer._normalize(transcription)
|
353 |
+
return batch
|
354 |
+
|
355 |
+
|
356 |
+
# output_dir="./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400",
|
357 |
+
testing_args = Seq2SeqTrainingArguments(
|
358 |
+
# change to a repo name of your choice
|
359 |
+
output_dir="fine_tuned/SSD/model/whipser_medium_TEP_patient_TL_TL",
|
360 |
+
per_device_train_batch_size=8,
|
361 |
+
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
|
362 |
+
learning_rate=1e-5,
|
363 |
+
warmup_steps=100,
|
364 |
+
max_steps=1000,
|
365 |
+
gradient_checkpointing=True,
|
366 |
+
fp16=True,
|
367 |
+
evaluation_strategy="steps",
|
368 |
+
per_device_eval_batch_size=8,
|
369 |
+
predict_with_generate=True,
|
370 |
+
generation_max_length=512,
|
371 |
+
save_steps=100,
|
372 |
+
eval_steps=25,
|
373 |
+
logging_steps=100,
|
374 |
+
report_to=["tensorboard"],
|
375 |
+
load_best_model_at_end=True,
|
376 |
+
metric_for_best_model="wer",
|
377 |
+
greater_is_better=False,
|
378 |
+
push_to_hub=False,
|
379 |
+
)
|
380 |
+
|
381 |
+
predict_trainer = Seq2SeqTrainer(
|
382 |
+
args=testing_args,
|
383 |
+
model=fine_tuned_model,
|
384 |
+
data_collator=data_collator,
|
385 |
+
compute_metrics=compute_metrics,
|
386 |
+
tokenizer=processor.feature_extractor,
|
387 |
+
)
|
388 |
+
|
389 |
+
# trainer.train()
|
390 |
+
# fine tuned
|
391 |
+
# z_result = encoded_test.map(fine_tuned_map_to_pred)
|
392 |
+
pdb.set_trace()
|
393 |
+
z_result= encode_negel_79_test.map(fine_tuned_map_to_pred)
|
394 |
+
# 0.4692737430167598
|
395 |
+
z = WER.compute(references=z_result['reference'], predictions=z_result['prediction'])
|
396 |
+
# pdb.set_trace()
|
397 |
+
# z_hel_result = encoded_healthy.map(fine_tuned_map_to_pred)
|
398 |
+
# z_hel = WER.compute(references=z_hel_result['reference'], predictions=z_hel_result['prediction'])
|
399 |
+
# # 0.1591610117211598
|
400 |
+
|
401 |
+
# # pdb.set_trace()
|
402 |
+
# # z_fary_result = encoded_Fary.map(fine_tuned_map_to_pred)
|
403 |
+
# # z_far = WER.compute(references=z_fary_result['reference'], predictions=z_fary_result['prediction'])
|
404 |
+
# # 0.1791044776119403
|
405 |
+
# z_patient_LT = encoded_patient_TL_test.map(fine_tuned_map_to_pred)
|
406 |
+
# z_patient_LT_result = WER.compute(references=z_patient_LT['reference'], predictions=z_patient_LT['prediction'])
|
407 |
+
# z_patient_L = encoded_patient_L_test.map(fine_tuned_map_to_pred)
|
408 |
+
# z_patient_L_result = WER.compute(references=z_patient_L['reference'], predictions=z_patient_L['prediction'])
|
409 |
+
# z_patient_T = encoded_patient_T_test.map(fine_tuned_map_to_pred)
|
410 |
+
# z_patient_T_result = WER.compute(references=z_patient_T['reference'], predictions=z_patient_T['prediction'])
|
411 |
+
|
412 |
+
# # z_john_p326_result = encoded_John_p326.map(fine_tuned_map_to_pred)
|
413 |
+
# # pdb.set_trace()
|
414 |
+
|
415 |
+
# # z_john_p326 = WER.compute(references=z_john_p326_result['reference'], predictions=z_john_p326_result['prediction'])
|
416 |
+
# # 0.4648241206030151
|
417 |
+
pdb.set_trace()
|
418 |
+
|
419 |
+
# # y_John_video= fine_tuned_trainer.predict(encoded_John_video)
|
420 |
+
# # metrics={'test_loss': 2.665189743041992, 'test_wer': 0.7222222222222222, 'test_runtime': 0.1633, 'test_samples_per_second': 48.979, 'test_steps_per_second': 6.122})
|
421 |
+
# pdb.set_trace()
|
422 |
+
|
423 |
+
# p326 training
|
424 |
+
# metrics={'test_loss': 0.4804028868675232, 'test_wer': 0.21787709497206703, 'test_runtime': 0.3594, 'test_samples_per_second': 44.517, 'test_steps_per_second': 5.565})
|
425 |
+
# hel metrics={'test_loss': 1.6363693475723267, 'test_wer': 0.17951881554595928, 'test_runtime': 3.8451, 'test_samples_per_second': 41.611, 'test_steps_per_second': 5.201})
|
426 |
+
# Fary: metrics={'t est_loss': 1.4633615016937256, 'test_wer': 0.5572139303482587, 'test_runtime': 0.6627, 'test_samples_per_second': 45.27, 'test_steps_per_second': 6.036})
|
427 |
+
# p326 large: metrics={'test_loss': 0.6568527817726135, 'test_wer': 0.2889447236180904, 'test_runtime': 0.7169, 'test_samples_per_second': 51.613, 'test_steps_per_second': 6.975})
|
local/duration_calcutator.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import librosa
|
3 |
+
|
4 |
+
folder_path = "/home/kevingeng/Disk2/laronix/laronix_automos/data/Patient_sil_trim_16k_normed_5_snr_40/Sentences" # Replace with the path to your folder
|
5 |
+
total_duration = 0.0
|
6 |
+
|
7 |
+
# Iterate through all files in the folder
|
8 |
+
for filename in os.listdir(folder_path):
|
9 |
+
file_path = os.path.join(folder_path, filename)
|
10 |
+
if os.path.isfile(file_path):
|
11 |
+
try:
|
12 |
+
# Load the audio file and get its duration
|
13 |
+
audio_data, _ = librosa.load(file_path)
|
14 |
+
duration = librosa.get_duration(audio_data)
|
15 |
+
total_duration += duration
|
16 |
+
except Exception as e:
|
17 |
+
print(f"Error processing file '{filename}': {e}")
|
18 |
+
|
19 |
+
print(f"Total duration of audio files in the folder: {total_duration} seconds.")
|
local/fine-tuning.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
TODO:
|
3 |
+
+ [x] Load Configuration
|
4 |
+
+ [ ] Multi ASR Engine
|
5 |
+
+ [ ] Batch / Real Time support
|
6 |
+
"""
|
7 |
+
from pathlib import Path
|
8 |
+
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
|
9 |
+
from datasets import load_dataset
|
10 |
+
from datasets import Dataset, Audio
|
11 |
+
import pdb
|
12 |
+
import string
|
13 |
+
# local import
|
14 |
+
import sys
|
15 |
+
|
16 |
+
sys.path.append("src")
|
17 |
+
|
18 |
+
# token_model = AutoModelForCTC.from_pretrained(
|
19 |
+
# "facebook/wav2vec2-base-960h"
|
20 |
+
# )
|
21 |
+
|
22 |
+
# ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
23 |
+
|
24 |
+
audio_path = "/Users/kevingeng/Laronix/Laronix_PAL_ASR_Offline_Plot/data/samples/3_Healthy1.wav"
|
25 |
+
|
26 |
+
audio_dir= "/Users/kevingeng/Laronix/laronix_automos/data/Patient_sil_trim_16k_normed_5_snr_40/"
|
27 |
+
# tgt_audio_dir= "/Users/kevingeng/Laronix/Dataset/Pneumatic/automos"
|
28 |
+
|
29 |
+
# src_audio_list = sorted(Path(src_audio_dir).glob("**/*.wav"))
|
30 |
+
# src_audio_list = [str(x) for x in src_audio_list]
|
31 |
+
# src_audio_dict = {"audio": src_audio_list}
|
32 |
+
# src_dataset = Dataset.from_dict(src_audio_dict).cast_column("audio", Audio())
|
33 |
+
|
34 |
+
# tgt_audio_list = sorted(Path(tgt_audio_dir).glob("**/*.wav"))
|
35 |
+
# tgt_audio_list = [str(x) for x in tgt_audio_list]
|
36 |
+
# tgt_audio_dict = {"audio": tgt_audio_list}
|
37 |
+
# tgt_dataset = Dataset.from_dict(tgt_audio_dict).cast_column("audio", Audio())
|
38 |
+
|
39 |
+
# Get Transcription, WER and PPM
|
40 |
+
"""
|
41 |
+
TODO:
|
42 |
+
[DONE]: Automatic generating Config
|
43 |
+
"""
|
44 |
+
|
45 |
+
import yaml
|
46 |
+
import argparse
|
47 |
+
import sys
|
48 |
+
from pathlib import Path
|
49 |
+
|
50 |
+
sys.path.append("./src")
|
51 |
+
import lightning_module
|
52 |
+
from UV import plot_UV, get_speech_interval
|
53 |
+
from transformers import pipeline
|
54 |
+
from rich.progress import track
|
55 |
+
from rich import print as rprint
|
56 |
+
import numpy as np
|
57 |
+
import jiwer
|
58 |
+
import pdb
|
59 |
+
import torch.nn as nn
|
60 |
+
import torch
|
61 |
+
import torchaudio
|
62 |
+
import gradio as gr
|
63 |
+
from sys import flags
|
64 |
+
from random import sample
|
65 |
+
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
66 |
+
|
67 |
+
# root_path = Path(__file__).parents[1]
|
68 |
+
|
69 |
+
class ChangeSampleRate(nn.Module):
|
70 |
+
def __init__(self, input_rate: int, output_rate: int):
|
71 |
+
super().__init__()
|
72 |
+
self.output_rate = output_rate
|
73 |
+
self.input_rate = input_rate
|
74 |
+
|
75 |
+
def forward(self, wav: torch.tensor) -> torch.tensor:
|
76 |
+
# Only accepts 1-channel waveform input
|
77 |
+
wav = wav.view(wav.size(0), -1)
|
78 |
+
new_length = wav.size(-1) * self.output_rate // self.input_rate
|
79 |
+
indices = torch.arange(new_length) * (
|
80 |
+
self.input_rate / self.output_rate
|
81 |
+
)
|
82 |
+
round_down = wav[:, indices.long()]
|
83 |
+
round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
|
84 |
+
output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
|
85 |
+
0
|
86 |
+
) + round_up * indices.fmod(1.0).unsqueeze(0)
|
87 |
+
return output
|
88 |
+
|
89 |
+
|
90 |
+
model = lightning_module.BaselineLightningModule.load_from_checkpoint(
|
91 |
+
"./src/epoch=3-step=7459.ckpt"
|
92 |
+
).eval()
|
93 |
+
|
94 |
+
|
95 |
+
def calc_wer(audio_path, ref):
|
96 |
+
wav, sr = torchaudio.load(audio_path)
|
97 |
+
osr = 16_000
|
98 |
+
batch = wav.unsqueeze(0).repeat(10, 1, 1)
|
99 |
+
csr = ChangeSampleRate(sr, osr)
|
100 |
+
out_wavs = csr(wav)
|
101 |
+
# ASR
|
102 |
+
trans = p(audio_path)["text"]
|
103 |
+
# WER
|
104 |
+
wer = jiwer.wer(
|
105 |
+
ref,
|
106 |
+
trans,
|
107 |
+
truth_transform=transformation,
|
108 |
+
hypothesis_transform=transformation,
|
109 |
+
)
|
110 |
+
return trans, wer
|
111 |
+
|
112 |
+
# if __name__ == "__main__":
|
113 |
+
# # Argparse
|
114 |
+
# parser = argparse.ArgumentParser(
|
115 |
+
# prog="get_ref_PPM",
|
116 |
+
# description="Generate Phoneme per Minute (and Voice/Unvoice plot)",
|
117 |
+
# epilog="",
|
118 |
+
# )
|
119 |
+
# parser.add_argument(
|
120 |
+
# "--tag",
|
121 |
+
# type=str,
|
122 |
+
# default=None,
|
123 |
+
# required=False,
|
124 |
+
# help="ID tag for output *.csv",
|
125 |
+
# )
|
126 |
+
|
127 |
+
# parser.add_argument("--ref_txt", type=str, required=True, help="Reference TXT")
|
128 |
+
# parser.add_argument(
|
129 |
+
# "--ref_wavs", type=str, required=True, help="Reference WAVs"
|
130 |
+
# )
|
131 |
+
|
132 |
+
# parser.add_argument(
|
133 |
+
# "--output_dir",
|
134 |
+
# type=str,
|
135 |
+
# required=True,
|
136 |
+
# help="Output Directory for *.csv",
|
137 |
+
# )
|
138 |
+
# parser.add_argument(
|
139 |
+
# "--to_config",
|
140 |
+
# choices=["True", "False"],
|
141 |
+
# default="False",
|
142 |
+
# help="Generating Config from .txt and wavs/*wav",
|
143 |
+
# )
|
144 |
+
|
145 |
+
|
146 |
+
# args = parser.parse_args()
|
147 |
+
|
148 |
+
# refs = np.loadtxt(args.ref_txt, delimiter="\n", dtype="str")
|
149 |
+
# refs_ids = [x.split()[0] for x in refs]
|
150 |
+
# refs_txt = [" ".join(x.split()[1:]) for x in refs]
|
151 |
+
# ref_wavs = [str(x) for x in sorted(Path(args.ref_wavs).glob("**/*.wav"))]
|
152 |
+
# # pdb.set_trace()
|
153 |
+
# try:
|
154 |
+
# len(refs) == len(ref_wavs)
|
155 |
+
# except ValueError:
|
156 |
+
# print("Error: Text and Wavs don't match")
|
157 |
+
# exit()
|
158 |
+
|
159 |
+
# # ASR part
|
160 |
+
# p = pipeline("automatic-speech-recognition")
|
161 |
+
|
162 |
+
# # WER part
|
163 |
+
# transformation = jiwer.Compose(
|
164 |
+
# [
|
165 |
+
# jiwer.ToLowerCase(),
|
166 |
+
# jiwer.RemoveWhiteSpace(replace_by_space=True),
|
167 |
+
# jiwer.RemoveMultipleSpaces(),
|
168 |
+
# jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
|
169 |
+
# ]
|
170 |
+
# )
|
171 |
+
|
172 |
+
# # WPM part
|
173 |
+
# processor = Wav2Vec2Processor.from_pretrained(
|
174 |
+
# "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
|
175 |
+
# )
|
176 |
+
# phoneme_model = Wav2Vec2ForCTC.from_pretrained(
|
177 |
+
# "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
|
178 |
+
# )
|
179 |
+
# # phoneme_model = pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
|
180 |
+
|
181 |
+
# description = """
|
182 |
+
# MOS prediction demo using UTMOS-strong w/o phoneme encoder model, \
|
183 |
+
# which is trained on the main track dataset.
|
184 |
+
# This demo only accepts .wav format. Best at 16 kHz sampling rate.
|
185 |
+
|
186 |
+
# Paper is available [here](https://arxiv.org/abs/2204.02152)
|
187 |
+
|
188 |
+
# Add ASR based on wav2vec-960, currently only English available.
|
189 |
+
# Add WER interface.
|
190 |
+
# """
|
191 |
+
|
192 |
+
# referance_id = gr.Textbox(
|
193 |
+
# value="ID", placeholder="Utter ID", label="Reference_ID"
|
194 |
+
# )
|
195 |
+
# referance_textbox = gr.Textbox(
|
196 |
+
# value="", placeholder="Input reference here", label="Reference"
|
197 |
+
# )
|
198 |
+
# # Set up interface
|
199 |
+
# result = []
|
200 |
+
# result.append("id, trans, wer")
|
201 |
+
|
202 |
+
|
203 |
+
# for id, x, y in track(
|
204 |
+
# zip(refs_ids, ref_wavs, refs_txt),
|
205 |
+
# total=len(refs_ids),
|
206 |
+
# description="Loading references information",
|
207 |
+
# ):
|
208 |
+
# trans, wer = calc_wer(x, y)
|
209 |
+
# record = ",".join(
|
210 |
+
# [
|
211 |
+
# id,
|
212 |
+
# str(trans),
|
213 |
+
# str(wer)
|
214 |
+
# ]
|
215 |
+
# )
|
216 |
+
# result.append(record)
|
217 |
+
|
218 |
+
# # Output
|
219 |
+
# if args.tag == None:
|
220 |
+
# args.tag = Path(args.ref_wavs).stem
|
221 |
+
# # Make output_dir
|
222 |
+
# # pdb.set_trace()
|
223 |
+
# Path.mkdir(Path(args.output_dir), exist_ok=True)
|
224 |
+
# # pdb.set_trace()
|
225 |
+
# with open("%s/%s.csv" % (args.output_dir, args.tag), "w") as f:
|
226 |
+
# print("\n".join(result), file=f)
|
227 |
+
|
228 |
+
# # Generating config
|
229 |
+
# if args.to_config == "True":
|
230 |
+
# config_dict = {
|
231 |
+
# "exp_id": args.tag,
|
232 |
+
# "ref_txt": args.ref_txt,
|
233 |
+
# "ref_feature": "%s/%s.csv" % (args.output_dir, args.tag),
|
234 |
+
# "ref_wavs": args.ref_wavs,
|
235 |
+
# "thre": {
|
236 |
+
# "minppm": 100,
|
237 |
+
# "maxppm": 100,
|
238 |
+
# "WER": 0.1,
|
239 |
+
# "AUTOMOS": 4.0,
|
240 |
+
# },
|
241 |
+
# "auth": {"username": None, "password": None},
|
242 |
+
# }
|
243 |
+
# with open("./config/%s.yaml" % args.tag, "w") as config_f:
|
244 |
+
# rprint("Dumping as config ./config/%s.yaml" % args.tag)
|
245 |
+
# rprint(config_dict)
|
246 |
+
# yaml.dump(config_dict, stream=config_f)
|
247 |
+
# rprint("Change parameter ./config/%s.yaml if necessary" % args.tag)
|
248 |
+
# print("Reference Dumping Finished")
|
249 |
+
def dataclean(example):
|
250 |
+
return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
251 |
+
|
252 |
+
# processor = AutoFeatureExtractor.from_pretrained(
|
253 |
+
# "facebook/wav2vec2-base-960h"
|
254 |
+
# )
|
255 |
+
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
|
256 |
+
|
257 |
+
def prepare_dataset(batch):
|
258 |
+
audio = batch["audio"]
|
259 |
+
batch = processor(audio["array"], sampling_rate = audio["sampling_rate"], text=batch['transcription'])
|
260 |
+
batch["input_length"] = len(batch["input_values"][0])
|
261 |
+
return batch
|
262 |
+
|
263 |
+
src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
|
264 |
+
src_dataset = src_dataset.map(dataclean)
|
265 |
+
# train_dev / test
|
266 |
+
ds = src_dataset.train_test_split(test_size=0.1)
|
267 |
+
|
268 |
+
train_dev = ds['train']
|
269 |
+
# train / dev
|
270 |
+
train_dev = train_dev.train_test_split(test_size=int(len(src_dataset)*0.1))
|
271 |
+
# train/dev/test
|
272 |
+
train = train_dev['train']
|
273 |
+
test = ds['test']
|
274 |
+
dev = train_dev['test']
|
275 |
+
|
276 |
+
# pdb.set_trace()
|
277 |
+
import numpy as np
|
278 |
+
|
279 |
+
|
280 |
+
def compute_metrics(pred):
|
281 |
+
pred_logits = pred.predictions
|
282 |
+
pred_ids = np.argmax(pred_logits, axis=-1)
|
283 |
+
|
284 |
+
pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
|
285 |
+
|
286 |
+
pred_str = processor.batch_decode(pred_ids)
|
287 |
+
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
|
288 |
+
|
289 |
+
wer = wer.compute(predictions=pred_str, references=label_str)
|
290 |
+
|
291 |
+
return {"wer": wer}
|
292 |
+
|
293 |
+
|
294 |
+
pdb.set_trace()
|
295 |
+
# TOKENLIZER("data/samples/5_Laronix1.wav")
|
296 |
+
# pdb.set_trace()
|
297 |
+
# tokenizer
|
298 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
|
299 |
+
|
300 |
+
encoded_train = train.map(prepare_dataset, num_proc=4)
|
301 |
+
|
302 |
+
from transformers import AutoModelForCTC, TrainingArguments, Trainer
|
303 |
+
|
304 |
+
model = AutoModelForCTC.from_pretrained(
|
305 |
+
"facebook/wav2vec2-base",
|
306 |
+
ctc_loss_reduction="mean",
|
307 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
308 |
+
)
|
309 |
+
pdb.set_trace()
|
310 |
+
|
311 |
+
training_args = TrainingArguments(
|
312 |
+
output_dir="my_awesome_asr_mind_model",
|
313 |
+
per_device_train_batch_size=8,
|
314 |
+
gradient_accumulation_steps=2,
|
315 |
+
learning_rate=1e-5,
|
316 |
+
warmup_steps=500,
|
317 |
+
max_steps=2000,
|
318 |
+
gradient_checkpointing=True,
|
319 |
+
fp16=True,
|
320 |
+
group_by_length=True,
|
321 |
+
evaluation_strategy="steps",
|
322 |
+
per_device_eval_batch_size=8,
|
323 |
+
save_steps=1000,
|
324 |
+
eval_steps=1000,
|
325 |
+
logging_steps=25,
|
326 |
+
load_best_model_at_end=True,
|
327 |
+
metric_for_best_model="wer",
|
328 |
+
greater_is_better=False,
|
329 |
+
push_to_hub=True,
|
330 |
+
)
|
331 |
+
|
332 |
+
pdb.set_trace()
|
333 |
+
trainer = Trainer(
|
334 |
+
model=model,
|
335 |
+
args=training_args,
|
336 |
+
train_dataset=encoded_train["train"],
|
337 |
+
eval_dataset=encoded_train["test"],
|
338 |
+
tokenizer=processor.feature_extractor,
|
339 |
+
compute_metrics=compute_metrics,
|
340 |
+
)
|
341 |
+
pdb.set_trace()
|
342 |
+
# data_collator=data_collator,
|
343 |
+
|
344 |
+
trainer.train()
|
345 |
+
# x = tokenizer(test['transcription'][0])
|
346 |
+
|
local/get_ASR.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Get Transcription, WER and PPM
|
2 |
+
"""
|
3 |
+
TODO:
|
4 |
+
[DONE]: Automatic generating Config
|
5 |
+
"""
|
6 |
+
|
7 |
+
import yaml
|
8 |
+
import argparse
|
9 |
+
import sys
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
sys.path.append("./src")
|
13 |
+
import lightning_module
|
14 |
+
from UV import plot_UV, get_speech_interval
|
15 |
+
from transformers import pipeline
|
16 |
+
from rich.progress import track
|
17 |
+
from rich import print as rprint
|
18 |
+
import numpy as np
|
19 |
+
import jiwer
|
20 |
+
import pdb
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch
|
23 |
+
import torchaudio
|
24 |
+
import gradio as gr
|
25 |
+
from sys import flags
|
26 |
+
from random import sample
|
27 |
+
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
28 |
+
|
29 |
+
# root_path = Path(__file__).parents[1]
|
30 |
+
|
31 |
+
class ChangeSampleRate(nn.Module):
|
32 |
+
def __init__(self, input_rate: int, output_rate: int):
|
33 |
+
super().__init__()
|
34 |
+
self.output_rate = output_rate
|
35 |
+
self.input_rate = input_rate
|
36 |
+
|
37 |
+
def forward(self, wav: torch.tensor) -> torch.tensor:
|
38 |
+
# Only accepts 1-channel waveform input
|
39 |
+
wav = wav.view(wav.size(0), -1)
|
40 |
+
new_length = wav.size(-1) * self.output_rate // self.input_rate
|
41 |
+
indices = torch.arange(new_length) * (
|
42 |
+
self.input_rate / self.output_rate
|
43 |
+
)
|
44 |
+
round_down = wav[:, indices.long()]
|
45 |
+
round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
|
46 |
+
output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
|
47 |
+
0
|
48 |
+
) + round_up * indices.fmod(1.0).unsqueeze(0)
|
49 |
+
return output
|
50 |
+
|
51 |
+
|
52 |
+
model = lightning_module.BaselineLightningModule.load_from_checkpoint(
|
53 |
+
"./src/epoch=3-step=7459.ckpt"
|
54 |
+
).eval()
|
55 |
+
|
56 |
+
|
57 |
+
def calc_wer(audio_path, ref, ASR_pipeline):
|
58 |
+
wav, sr = torchaudio.load(audio_path)
|
59 |
+
osr = 16_000
|
60 |
+
batch = wav.unsqueeze(0).repeat(10, 1, 1)
|
61 |
+
csr = ChangeSampleRate(sr, osr)
|
62 |
+
out_wavs = csr(wav)
|
63 |
+
# ASR
|
64 |
+
trans = ASR_pipeline(audio_path)["text"]
|
65 |
+
# WER
|
66 |
+
wer = jiwer.wer(
|
67 |
+
ref,
|
68 |
+
trans,
|
69 |
+
truth_transform=transformation,
|
70 |
+
hypothesis_transform=transformation,
|
71 |
+
)
|
72 |
+
return trans, wer
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
# Argparse
|
76 |
+
parser = argparse.ArgumentParser(
|
77 |
+
prog="get_ref_PPM",
|
78 |
+
description="Generate Phoneme per Minute (and Voice/Unvoice plot)",
|
79 |
+
epilog="",
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--tag",
|
83 |
+
type=str,
|
84 |
+
default=None,
|
85 |
+
required=False,
|
86 |
+
help="ID tag for output *.csv",
|
87 |
+
)
|
88 |
+
|
89 |
+
parser.add_argument("--ref_txt", type=str, required=True, help="Reference TXT")
|
90 |
+
parser.add_argument(
|
91 |
+
"--ref_wavs", type=str, required=True, help="Reference WAVs"
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--metadata",
|
95 |
+
type=str,
|
96 |
+
required=False,
|
97 |
+
help="metadata.csv including wav_id and reference",
|
98 |
+
)
|
99 |
+
|
100 |
+
parser.add_argument(
|
101 |
+
"--model",
|
102 |
+
type=str,
|
103 |
+
default='whisper-medium-FT',
|
104 |
+
choices=['wav2vec+ctc', 'whipser-medium-FT', 'whipser-large-v2'],
|
105 |
+
help="ASR engine for evaluation:\n ver1: wav2vec+ctc \n ver2: whipser-medium(Fined-tuned)\n ver3: whipser-large-v2",
|
106 |
+
)
|
107 |
+
|
108 |
+
parser.add_argument(
|
109 |
+
"--output_dir",
|
110 |
+
type=str,
|
111 |
+
required=True,
|
112 |
+
help="Output Directory for *.csv",
|
113 |
+
)
|
114 |
+
|
115 |
+
parser.add_argument(
|
116 |
+
"--to_config",
|
117 |
+
choices=["True", "False"],
|
118 |
+
default="False",
|
119 |
+
help="Generating Config from .txt and wavs/*wav",
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
args = parser.parse_args()
|
124 |
+
|
125 |
+
refs = np.loadtxt(args.ref_txt, delimiter="\n", dtype="str")
|
126 |
+
refs_ids = [x.split()[0] for x in refs]
|
127 |
+
refs_txt = [" ".join(x.split()[1:]) for x in refs]
|
128 |
+
ref_wavs = [str(x) for x in sorted(Path(args.ref_wavs).glob("**/*.wav"))]
|
129 |
+
# pdb.set_trace()
|
130 |
+
try:
|
131 |
+
len(refs) == len(ref_wavs)
|
132 |
+
except ValueError:
|
133 |
+
print("Error: Text and Wavs don't match")
|
134 |
+
exit()
|
135 |
+
|
136 |
+
# ASR part
|
137 |
+
if args.model== "whisper-medium-FT":
|
138 |
+
ASR_pipeline = pipeline("automatic-speech-recognition", model="KevinGeng/whipser_medium_en_PAL300_step25")
|
139 |
+
elif args.model == "wav2vec+ctc":
|
140 |
+
ASR_pipeline = pipeline("automatic-speech-recognition")
|
141 |
+
elif args.model == "whisper-large-v2":
|
142 |
+
ASR_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-large-v2")
|
143 |
+
|
144 |
+
# pdb.set_trace()
|
145 |
+
# WER part
|
146 |
+
transformation = jiwer.Compose(
|
147 |
+
[
|
148 |
+
jiwer.ToLowerCase(),
|
149 |
+
jiwer.RemoveWhiteSpace(replace_by_space=True),
|
150 |
+
jiwer.RemoveMultipleSpaces(),
|
151 |
+
jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
|
152 |
+
]
|
153 |
+
)
|
154 |
+
|
155 |
+
# WPM part
|
156 |
+
processor = Wav2Vec2Processor.from_pretrained(
|
157 |
+
"facebook/wav2vec2-xlsr-53-espeak-cv-ft"
|
158 |
+
)
|
159 |
+
phoneme_model = Wav2Vec2ForCTC.from_pretrained(
|
160 |
+
"facebook/wav2vec2-xlsr-53-espeak-cv-ft"
|
161 |
+
)
|
162 |
+
# phoneme_model = pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
|
163 |
+
|
164 |
+
description = """
|
165 |
+
MOS prediction demo using UTMOS-strong w/o phoneme encoder model, \
|
166 |
+
which is trained on the main track dataset.
|
167 |
+
This demo only accepts .wav format. Best at 16 kHz sampling rate.
|
168 |
+
|
169 |
+
Paper is available [here](https://arxiv.org/abs/2204.02152)
|
170 |
+
|
171 |
+
Add ASR based on wav2vec-960, currently only English available.
|
172 |
+
Add WER interface.
|
173 |
+
"""
|
174 |
+
|
175 |
+
referance_id = gr.Textbox(
|
176 |
+
value="ID", placeholder="Utter ID", label="Reference_ID"
|
177 |
+
)
|
178 |
+
referance_textbox = gr.Textbox(
|
179 |
+
value="", placeholder="Input reference here", label="Reference"
|
180 |
+
)
|
181 |
+
# Set up interface
|
182 |
+
result = []
|
183 |
+
result.append("id,ref,hyp,wer")
|
184 |
+
|
185 |
+
|
186 |
+
for id, x, y in track(
|
187 |
+
zip(refs_ids, ref_wavs, refs_txt),
|
188 |
+
total=len(refs_ids),
|
189 |
+
description="Loading references information",
|
190 |
+
):
|
191 |
+
trans, wer = calc_wer(x, y, ASR_pipeline=ASR_pipeline)
|
192 |
+
record = ",".join(
|
193 |
+
[
|
194 |
+
id,
|
195 |
+
str(y),
|
196 |
+
str(trans),
|
197 |
+
str(wer)
|
198 |
+
]
|
199 |
+
)
|
200 |
+
result.append(record)
|
201 |
+
|
202 |
+
# Output
|
203 |
+
if args.tag == None:
|
204 |
+
args.tag = Path(args.ref_wavs).stem
|
205 |
+
# Make output_dir
|
206 |
+
# pdb.set_trace()
|
207 |
+
Path.mkdir(Path(args.output_dir), exist_ok=True)
|
208 |
+
# pdb.set_trace()
|
209 |
+
with open("%s/%s.csv" % (args.output_dir, args.tag), "w") as f:
|
210 |
+
print("\n".join(result), file=f)
|
211 |
+
|
212 |
+
# Generating config
|
213 |
+
if args.to_config == "True":
|
214 |
+
config_dict = {
|
215 |
+
"exp_id": args.tag,
|
216 |
+
"ref_txt": args.ref_txt,
|
217 |
+
"ref_feature": "%s/%s.csv" % (args.output_dir, args.tag),
|
218 |
+
"ref_wavs": args.ref_wavs,
|
219 |
+
"thre": {
|
220 |
+
"minppm": 100,
|
221 |
+
"maxppm": 100,
|
222 |
+
"WER": 0.1,
|
223 |
+
"AUTOMOS": 4.0,
|
224 |
+
},
|
225 |
+
"auth": {"username": None, "password": None},
|
226 |
+
}
|
227 |
+
with open("./config/%s.yaml" % args.tag, "w") as config_f:
|
228 |
+
rprint("Dumping as config ./config/%s.yaml" % args.tag)
|
229 |
+
rprint(config_dict)
|
230 |
+
yaml.dump(config_dict, stream=config_f)
|
231 |
+
rprint("Change parameter ./config/%s.yaml if necessary" % args.tag)
|
232 |
+
print("Reference Dumping Finished")
|
local/get_ref_PPM.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Get Transcription, WER and PPM
|
2 |
+
"""
|
3 |
+
TODO:
|
4 |
+
[DONE]: Automatic generating Config
|
5 |
+
"""
|
6 |
+
|
7 |
+
import yaml
|
8 |
+
import argparse
|
9 |
+
import sys
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
sys.path.append("./src")
|
13 |
+
import lightning_module
|
14 |
+
from UV import plot_UV, get_speech_interval
|
15 |
+
from transformers import pipeline
|
16 |
+
from rich.progress import track
|
17 |
+
from rich import print as rprint
|
18 |
+
import numpy as np
|
19 |
+
import jiwer
|
20 |
+
import pdb
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch
|
23 |
+
import torchaudio
|
24 |
+
import gradio as gr
|
25 |
+
from sys import flags
|
26 |
+
from random import sample
|
27 |
+
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
28 |
+
|
29 |
+
# root_path = Path(__file__).parents[1]
|
30 |
+
|
31 |
+
class ChangeSampleRate(nn.Module):
|
32 |
+
def __init__(self, input_rate: int, output_rate: int):
|
33 |
+
super().__init__()
|
34 |
+
self.output_rate = output_rate
|
35 |
+
self.input_rate = input_rate
|
36 |
+
|
37 |
+
def forward(self, wav: torch.tensor) -> torch.tensor:
|
38 |
+
# Only accepts 1-channel waveform input
|
39 |
+
wav = wav.view(wav.size(0), -1)
|
40 |
+
new_length = wav.size(-1) * self.output_rate // self.input_rate
|
41 |
+
indices = torch.arange(new_length) * (
|
42 |
+
self.input_rate / self.output_rate
|
43 |
+
)
|
44 |
+
round_down = wav[:, indices.long()]
|
45 |
+
round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
|
46 |
+
output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
|
47 |
+
0
|
48 |
+
) + round_up * indices.fmod(1.0).unsqueeze(0)
|
49 |
+
return output
|
50 |
+
|
51 |
+
|
52 |
+
model = lightning_module.BaselineLightningModule.load_from_checkpoint(
|
53 |
+
"./src/epoch=3-step=7459.ckpt"
|
54 |
+
).eval()
|
55 |
+
|
56 |
+
|
57 |
+
def calc_mos(audio_path, ref):
|
58 |
+
wav, sr = torchaudio.load(audio_path)
|
59 |
+
osr = 16_000
|
60 |
+
batch = wav.unsqueeze(0).repeat(10, 1, 1)
|
61 |
+
csr = ChangeSampleRate(sr, osr)
|
62 |
+
out_wavs = csr(wav)
|
63 |
+
# ASR
|
64 |
+
trans = p(audio_path)["text"]
|
65 |
+
# WER
|
66 |
+
wer = jiwer.wer(
|
67 |
+
ref,
|
68 |
+
trans,
|
69 |
+
truth_transform=transformation,
|
70 |
+
hypothesis_transform=transformation,
|
71 |
+
)
|
72 |
+
# MOS
|
73 |
+
batch = {
|
74 |
+
"wav": out_wavs,
|
75 |
+
"domains": torch.tensor([0]),
|
76 |
+
"judge_id": torch.tensor([288]),
|
77 |
+
}
|
78 |
+
with torch.no_grad():
|
79 |
+
output = model(batch)
|
80 |
+
predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
|
81 |
+
# Phonemes per minute (PPM)
|
82 |
+
with torch.no_grad():
|
83 |
+
logits = phoneme_model(out_wavs).logits
|
84 |
+
phone_predicted_ids = torch.argmax(logits, dim=-1)
|
85 |
+
phone_transcription = processor.batch_decode(phone_predicted_ids)
|
86 |
+
lst_phonemes = phone_transcription[0].split(" ")
|
87 |
+
wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
|
88 |
+
ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
|
89 |
+
# if float(predic_mos) >= 3.0:
|
90 |
+
# torchaudio.save("good.wav", wav,sr)
|
91 |
+
|
92 |
+
return predic_mos, trans, wer, phone_transcription, ppm
|
93 |
+
|
94 |
+
if __name__ == "__main__":
|
95 |
+
# Argparse
|
96 |
+
parser = argparse.ArgumentParser(
|
97 |
+
prog="get_ref_PPM",
|
98 |
+
description="Generate Phoneme per Minute (and Voice/Unvoice plot)",
|
99 |
+
epilog="",
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
"--tag",
|
103 |
+
type=str,
|
104 |
+
default=None,
|
105 |
+
required=False,
|
106 |
+
help="ID tag for output *.csv",
|
107 |
+
)
|
108 |
+
|
109 |
+
parser.add_argument("--ref_txt", type=str, required=True, help="Reference TXT")
|
110 |
+
parser.add_argument(
|
111 |
+
"--ref_wavs", type=str, required=True, help="Reference WAVs"
|
112 |
+
)
|
113 |
+
|
114 |
+
parser.add_argument(
|
115 |
+
"--output_dir",
|
116 |
+
type=str,
|
117 |
+
required=True,
|
118 |
+
help="Output Directory for *.csv",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--to_config",
|
122 |
+
choices=["True", "False"],
|
123 |
+
default="False",
|
124 |
+
help="Generating Config from .txt and wavs/*wav",
|
125 |
+
)
|
126 |
+
|
127 |
+
parser.add_argument(
|
128 |
+
"--UV_flag",
|
129 |
+
choices=["True", "False"],
|
130 |
+
default="False",
|
131 |
+
help="Toggle for U/V plot",
|
132 |
+
)
|
133 |
+
parser.add_argument(
|
134 |
+
"--UV_thre", type=float, default=40, help="U/V threshold dB"
|
135 |
+
)
|
136 |
+
args = parser.parse_args()
|
137 |
+
|
138 |
+
refs = np.loadtxt(args.ref_txt, delimiter="\n", dtype="str")
|
139 |
+
refs_ids = [x.split()[0] for x in refs]
|
140 |
+
refs_txt = [" ".join(x.split()[1:]) for x in refs]
|
141 |
+
ref_wavs = [str(x) for x in sorted(Path(args.ref_wavs).glob("**/*.wav"))]
|
142 |
+
# pdb.set_trace()
|
143 |
+
try:
|
144 |
+
len(refs) == len(ref_wavs)
|
145 |
+
except ValueError:
|
146 |
+
print("Error: Text and Wavs don't match")
|
147 |
+
exit()
|
148 |
+
|
149 |
+
# ASR part
|
150 |
+
p = pipeline("automatic-speech-recognition")
|
151 |
+
|
152 |
+
# WER part
|
153 |
+
transformation = jiwer.Compose(
|
154 |
+
[
|
155 |
+
jiwer.ToLowerCase(),
|
156 |
+
jiwer.RemoveWhiteSpace(replace_by_space=True),
|
157 |
+
jiwer.RemoveMultipleSpaces(),
|
158 |
+
jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
|
159 |
+
]
|
160 |
+
)
|
161 |
+
|
162 |
+
# WPM part
|
163 |
+
processor = Wav2Vec2Processor.from_pretrained(
|
164 |
+
"facebook/wav2vec2-xlsr-53-espeak-cv-ft"
|
165 |
+
)
|
166 |
+
phoneme_model = Wav2Vec2ForCTC.from_pretrained(
|
167 |
+
"facebook/wav2vec2-xlsr-53-espeak-cv-ft"
|
168 |
+
)
|
169 |
+
# phoneme_model = pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
|
170 |
+
|
171 |
+
description = """
|
172 |
+
MOS prediction demo using UTMOS-strong w/o phoneme encoder model, \
|
173 |
+
which is trained on the main track dataset.
|
174 |
+
This demo only accepts .wav format. Best at 16 kHz sampling rate.
|
175 |
+
|
176 |
+
Paper is available [here](https://arxiv.org/abs/2204.02152)
|
177 |
+
|
178 |
+
Add ASR based on wav2vec-960, currently only English available.
|
179 |
+
Add WER interface.
|
180 |
+
"""
|
181 |
+
|
182 |
+
referance_id = gr.Textbox(
|
183 |
+
value="ID", placeholder="Utter ID", label="Reference_ID"
|
184 |
+
)
|
185 |
+
referance_textbox = gr.Textbox(
|
186 |
+
value="", placeholder="Input reference here", label="Reference"
|
187 |
+
)
|
188 |
+
# Set up interface
|
189 |
+
result = []
|
190 |
+
result.append("id, pred_mos, trans, wer, pred_phone, ppm")
|
191 |
+
|
192 |
+
if args.UV_flag == "False":
|
193 |
+
for id, x, y in track(
|
194 |
+
zip(refs_ids, ref_wavs, refs_txt),
|
195 |
+
total=len(refs_ids),
|
196 |
+
description="Loading references information",
|
197 |
+
):
|
198 |
+
predic_mos, trans, wer, phone_transcription, ppm = calc_mos(x, y)
|
199 |
+
record = ",".join(
|
200 |
+
[
|
201 |
+
id,
|
202 |
+
str(predic_mos),
|
203 |
+
str(trans),
|
204 |
+
str(wer),
|
205 |
+
str(phone_transcription),
|
206 |
+
str(ppm),
|
207 |
+
]
|
208 |
+
)
|
209 |
+
result.append(record)
|
210 |
+
|
211 |
+
elif args.UV_flag == "True":
|
212 |
+
fig_tardir = Path(args.ref_wavs) / Path("PPM_figs")
|
213 |
+
Path.mkdir(Path(args.ref_wavs) / Path("PPM_figs"), exist_ok=True)
|
214 |
+
|
215 |
+
for id, x, y in track(
|
216 |
+
zip(refs_ids, ref_wavs, refs_txt),
|
217 |
+
total=len(refs_ids),
|
218 |
+
description="Loading references information",
|
219 |
+
):
|
220 |
+
# UV ploting
|
221 |
+
wav, sr = torchaudio.load(x)
|
222 |
+
wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
|
223 |
+
a_h, p_h = get_speech_interval(wav_vad.numpy(), db=args.UV_thre)
|
224 |
+
fig_h = plot_UV(wav_vad.numpy().squeeze(), a_h, sr=sr)
|
225 |
+
fig_h.savefig(Path(fig_tardir) / Path(id + ".png"), dpi=200)
|
226 |
+
# Acoustic calculation
|
227 |
+
predic_mos, trans, wer, phone_transcription, ppm = calc_mos(x, y)
|
228 |
+
record = ",".join(
|
229 |
+
[
|
230 |
+
id,
|
231 |
+
str(predic_mos),
|
232 |
+
str(trans),
|
233 |
+
str(wer),
|
234 |
+
str(phone_transcription),
|
235 |
+
str(ppm),
|
236 |
+
]
|
237 |
+
)
|
238 |
+
result.append(record)
|
239 |
+
# Output
|
240 |
+
if args.tag == None:
|
241 |
+
args.tag = Path(args.ref_wavs).stem
|
242 |
+
# Make output_dir
|
243 |
+
# pdb.set_trace()
|
244 |
+
Path.mkdir(Path(args.output_dir), exist_ok=True)
|
245 |
+
# pdb.set_trace()
|
246 |
+
with open("%s/%s.csv" % (args.output_dir, args.tag), "w") as f:
|
247 |
+
print("\n".join(result), file=f)
|
248 |
+
|
249 |
+
# Generating config
|
250 |
+
if args.to_config == "True":
|
251 |
+
config_dict = {
|
252 |
+
"exp_id": args.tag,
|
253 |
+
"ref_txt": args.ref_txt,
|
254 |
+
"ref_feature": "%s/%s.csv" % (args.output_dir, args.tag),
|
255 |
+
"ref_wavs": args.ref_wavs,
|
256 |
+
"thre": {
|
257 |
+
"minppm": 100,
|
258 |
+
"maxppm": 100,
|
259 |
+
"WER": 0.1,
|
260 |
+
"AUTOMOS": 4.0,
|
261 |
+
},
|
262 |
+
"auth": {"username": None, "password": None},
|
263 |
+
}
|
264 |
+
with open("./config/%s.yaml" % args.tag, "w") as config_f:
|
265 |
+
rprint("Dumping as config ./config/%s.yaml" % args.tag)
|
266 |
+
rprint(config_dict)
|
267 |
+
yaml.dump(config_dict, stream=config_f)
|
268 |
+
rprint("Change parameter ./config/%s.yaml if necessary" % args.tag)
|
269 |
+
print("Reference Dumping Finished")
|
local/new_whisper_fine_tuning.py
ADDED
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fine_tuning_dir = "/fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400"
|
2 |
+
|
3 |
+
"""
|
4 |
+
TODO:
|
5 |
+
+ [ ] Data load
|
6 |
+
+ [ ] Train / Test / Dev spilt
|
7 |
+
+ [ ] Train / Test Phase
|
8 |
+
+ [ ] Logging with Train / Dev / Test Loss
|
9 |
+
+ [ ] Evalutation metrics
|
10 |
+
"""
|
11 |
+
import pdb
|
12 |
+
import string
|
13 |
+
from pathlib import Path
|
14 |
+
|
15 |
+
import evaluate
|
16 |
+
import librosa
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from datasets import Dataset, concatenate_datasets, load_dataset
|
20 |
+
from transformers import AutoProcessor
|
21 |
+
|
22 |
+
wer = evaluate.load("wer")
|
23 |
+
torch.cuda.set_device("cuda:0")
|
24 |
+
|
25 |
+
audio_dir = "./data/Patient_sil_trim_16k_normed_5_snr_40"
|
26 |
+
healthy_dir = "./data/Healthy"
|
27 |
+
Fary_PAL_30 = "./data/Fary_PAL_p326_20230110_30"
|
28 |
+
John_p326 = "./data/John_p326/output"
|
29 |
+
John_video = "./data/20230103_video"
|
30 |
+
|
31 |
+
## train
|
32 |
+
p326_300_dir = "./data/John_p326_large"
|
33 |
+
P1tony_arthur = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Arthur_the_Rat/PAL"
|
34 |
+
P1tony_rainbow = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Rainbow_Passage/Laronix"
|
35 |
+
|
36 |
+
P1tony = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/CONVERSATIONAL/PAL"
|
37 |
+
|
38 |
+
P4Negel = 'data/4_negal_152_clean_all'
|
39 |
+
|
40 |
+
def dataclean(example):
|
41 |
+
if example["audio"]["sampling_rate"] != 16000:
|
42 |
+
resampled_audio = librosa.resample(
|
43 |
+
y=example["audio"]["array"],
|
44 |
+
orig_sr=example["audio"]["sampling_rate"],
|
45 |
+
target_sr=16000,
|
46 |
+
)
|
47 |
+
|
48 |
+
return {
|
49 |
+
"audio": {
|
50 |
+
"path": example["audio"]["path"],
|
51 |
+
"array": resampled_audio,
|
52 |
+
"sampling_rate": 16000,
|
53 |
+
},
|
54 |
+
"transcription": example["transcription"]
|
55 |
+
.upper()
|
56 |
+
.translate(str.maketrans("", "", string.punctuation)),
|
57 |
+
}
|
58 |
+
else:
|
59 |
+
return {
|
60 |
+
"transcription": example["transcription"]
|
61 |
+
.upper()
|
62 |
+
.translate(str.maketrans("", "", string.punctuation))
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
P1tony_dataset = load_dataset("audiofolder", data_dir=P1tony, split="train")
|
68 |
+
P1tony_dataset = P1tony_dataset.map(dataclean)
|
69 |
+
|
70 |
+
P1tony_scripted1 = load_dataset(
|
71 |
+
"audiofolder", data_dir=P1tony_rainbow, split="train"
|
72 |
+
)
|
73 |
+
P1tony_scripted2 = load_dataset(
|
74 |
+
"audiofolder", data_dir=P1tony_arthur, split="train"
|
75 |
+
)
|
76 |
+
P1tony_scripted1 = P1tony_scripted1.map(dataclean)
|
77 |
+
P1tony_scripted2 = P1tony_scripted2.map(dataclean)
|
78 |
+
P1tony_scripted = concatenate_datasets([P1tony_scripted1, P1tony_scripted2])
|
79 |
+
|
80 |
+
class ChangeSampleRate(nn.Module):
|
81 |
+
def __init__(self, input_rate: int, output_rate: int):
|
82 |
+
super().__init__()
|
83 |
+
self.output_rate = output_rate
|
84 |
+
self.input_rate = input_rate
|
85 |
+
|
86 |
+
def forward(self, wav: torch.tensor) -> torch.tensor:
|
87 |
+
# Only accepts 1-channel waveform input
|
88 |
+
wav = wav.view(wav.size(0), -1)
|
89 |
+
new_length = wav.size(-1) * self.output_rate // self.input_rate
|
90 |
+
indices = torch.arange(new_length) * (
|
91 |
+
self.input_rate / self.output_rate
|
92 |
+
)
|
93 |
+
round_down = wav[:, indices.long()]
|
94 |
+
round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
|
95 |
+
output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
|
96 |
+
0
|
97 |
+
) + round_up * indices.fmod(1.0).unsqueeze(0)
|
98 |
+
return output
|
99 |
+
|
100 |
+
# resample and clean text data
|
101 |
+
def dataclean(example):
|
102 |
+
# pdb.set_trace()
|
103 |
+
if example["audio"]["sampling_rate"] != 16000:
|
104 |
+
resampled_audio = librosa.resample(
|
105 |
+
y=example["audio"]["array"],
|
106 |
+
orig_sr=example["audio"]["sampling_rate"],
|
107 |
+
target_sr=16000,
|
108 |
+
)
|
109 |
+
|
110 |
+
return {
|
111 |
+
"audio": {
|
112 |
+
"path": example["audio"]["path"],
|
113 |
+
"array": resampled_audio,
|
114 |
+
"sampling_rate": 16000,
|
115 |
+
},
|
116 |
+
"transcription": example["transcription"]
|
117 |
+
.upper()
|
118 |
+
.translate(str.maketrans("", "", string.punctuation)),
|
119 |
+
}
|
120 |
+
else:
|
121 |
+
return {
|
122 |
+
"transcription": example["transcription"]
|
123 |
+
.upper()
|
124 |
+
.translate(str.maketrans("", "", string.punctuation))
|
125 |
+
}
|
126 |
+
|
127 |
+
|
128 |
+
# processor = AutoFeatureExtractor.from_pretrained(
|
129 |
+
# "facebook/wav2vec2-base-960h"
|
130 |
+
# )
|
131 |
+
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
132 |
+
|
133 |
+
def prepare_dataset(batch):
|
134 |
+
audio = batch["audio"]
|
135 |
+
batch = processor(
|
136 |
+
audio["array"],
|
137 |
+
sampling_rate=audio["sampling_rate"],
|
138 |
+
text=batch["transcription"],
|
139 |
+
)
|
140 |
+
batch["input_length"] = len(batch["input_values"][0])
|
141 |
+
return batch
|
142 |
+
|
143 |
+
src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
|
144 |
+
src_dataset = src_dataset.map(dataclean)
|
145 |
+
p326_300_dataset = load_dataset(
|
146 |
+
"audiofolder", data_dir=p326_300_dir, split="train"
|
147 |
+
)
|
148 |
+
p326_300_dataset = p326_300_dataset.map(dataclean)
|
149 |
+
|
150 |
+
P4Negel_dataset = load_dataset("audiofolder", data_dir=P4Negel, split="train")
|
151 |
+
P4Negel_dataset = P4Negel_dataset.map(dataclean)
|
152 |
+
|
153 |
+
healthy_test_dataset = load_dataset(
|
154 |
+
"audiofolder", data_dir=healthy_dir, split="train"
|
155 |
+
)
|
156 |
+
healthy_test_dataset = healthy_test_dataset.map(dataclean)
|
157 |
+
|
158 |
+
Fary_PAL_test_dataset = load_dataset(
|
159 |
+
"audiofolder", data_dir=Fary_PAL_30, split="train"
|
160 |
+
)
|
161 |
+
Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)
|
162 |
+
|
163 |
+
John_p326_test_dataset = load_dataset(
|
164 |
+
"audiofolder", data_dir=John_p326, split="train"
|
165 |
+
)
|
166 |
+
John_p326_test_dataset = John_p326_test_dataset.map(dataclean)
|
167 |
+
|
168 |
+
John_video_test_dataset = load_dataset(
|
169 |
+
"audiofolder", data_dir=John_video, split="train"
|
170 |
+
)
|
171 |
+
John_video_test_dataset = John_video_test_dataset.map(dataclean)
|
172 |
+
|
173 |
+
|
174 |
+
def train_dev_test_split(
|
175 |
+
dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1
|
176 |
+
):
|
177 |
+
"""
|
178 |
+
input: dataset
|
179 |
+
dev_rate,
|
180 |
+
test_rate
|
181 |
+
seed
|
182 |
+
-------
|
183 |
+
Output:
|
184 |
+
dataset_dict{"train", "dev", "test"}
|
185 |
+
"""
|
186 |
+
train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
|
187 |
+
test = train_dev_test["test"]
|
188 |
+
train_dev = train_dev_test["train"]
|
189 |
+
|
190 |
+
# pdb.set_trace()
|
191 |
+
if len(train_dev) <= int(len(dataset) * dev_rate):
|
192 |
+
train = Dataset.from_dict({"audio": [], "transcription": []})
|
193 |
+
dev = train_dev
|
194 |
+
else:
|
195 |
+
train_dev = train_dev.train_test_split(
|
196 |
+
test_size=int(len(dataset) * dev_rate), seed=seed
|
197 |
+
)
|
198 |
+
train = train_dev["train"]
|
199 |
+
dev = train_dev["test"]
|
200 |
+
return train, dev, test
|
201 |
+
|
202 |
+
P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(
|
203 |
+
P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1
|
204 |
+
)
|
205 |
+
P1tony_train_ = concatenate_datasets([P1tony_train, P1tony_scripted])
|
206 |
+
|
207 |
+
# train_dev / test
|
208 |
+
ds = src_dataset.train_test_split(test_size=0.1, seed=1)
|
209 |
+
|
210 |
+
# dataset_libri = load_dataset(
|
211 |
+
# "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
|
212 |
+
# )
|
213 |
+
|
214 |
+
train_dev = ds["train"]
|
215 |
+
# train / dev
|
216 |
+
train_dev = train_dev.train_test_split(
|
217 |
+
test_size=int(len(src_dataset) * 0.1), seed=1
|
218 |
+
)
|
219 |
+
|
220 |
+
# Tony
|
221 |
+
Tony_train = P1tony_train_
|
222 |
+
Tony_dev = P1tony_dev
|
223 |
+
Tony_test = P1tony_test
|
224 |
+
|
225 |
+
# John
|
226 |
+
John_train, John_dev, John_test = train_dev_test_split(p326_300_dataset, dev_rate=0.1, test_rate=0.1)
|
227 |
+
# Negel
|
228 |
+
Negel_train, Negel_dev, Negel_test = train_dev_test_split(P4Negel_dataset, dev_rate=0.1, test_rate=0.1)
|
229 |
+
|
230 |
+
# train/dev/test
|
231 |
+
train = train_dev["train"]
|
232 |
+
test = ds["test"]
|
233 |
+
dev = train_dev["test"]
|
234 |
+
|
235 |
+
# combined
|
236 |
+
combine_train = concatenate_datasets([train, Tony_train, John_train, Negel_train])
|
237 |
+
conbine_dev = concatenate_datasets([dev, Tony_dev, John_dev, Negel_dev])
|
238 |
+
conbine_test = concatenate_datasets([test, Tony_test, John_test, Negel_test])
|
239 |
+
|
240 |
+
# encoded_train = combine_train.map(prepare_dataset, num_proc=4)
|
241 |
+
# encoded_dev = conbine_dev.map(prepare_dataset, num_proc=4)
|
242 |
+
# encoded_test = conbine_test.map(prepare_dataset, num_proc=4)
|
243 |
+
|
244 |
+
# # extra_test
|
245 |
+
# encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
|
246 |
+
# encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
|
247 |
+
|
248 |
+
# encoded_ori_test = test.map(prepare_dataset, num_proc=4)
|
249 |
+
# encoded_Tony_test = Tony_test.map(prepare_dataset, num_proc=4)
|
250 |
+
# encoded_John_test = John_test.map(prepare_dataset, num_proc=4)
|
251 |
+
# encoded_Negel_test = Negel_test.map(prepare_dataset, num_proc=4)
|
252 |
+
|
253 |
+
# encoded_train = train.map(prepare_dataset, num_proc=4)
|
254 |
+
# encoded_dev = dev.map(prepare_dataset, num_proc=4)
|
255 |
+
# p326_encoded_train = p326_300_dataset.map(prepare_dataset, num_proc=4)
|
256 |
+
|
257 |
+
# combine large p326 in to training set
|
258 |
+
# encoded_train = concatenate_datasets([encoded_train, p326_encoded_train])
|
259 |
+
|
260 |
+
# encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
|
261 |
+
# encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
|
262 |
+
|
263 |
+
# pdb.set_trace()
|
264 |
+
import numpy as np
|
265 |
+
|
266 |
+
WER = evaluate.load("wer")
|
267 |
+
|
268 |
+
## Whisper decoding
|
269 |
+
|
270 |
+
from transformers import (Seq2SeqTrainer, Seq2SeqTrainingArguments,
|
271 |
+
WhisperFeatureExtractor,
|
272 |
+
WhisperForConditionalGeneration, WhisperModel,
|
273 |
+
WhisperProcessor, WhisperTokenizer)
|
274 |
+
|
275 |
+
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
|
276 |
+
# model = WhisperForConditionalGeneration.from_pretrained(
|
277 |
+
# "./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400",
|
278 |
+
# use_auth_token=True,
|
279 |
+
# ).to("cuda:0")
|
280 |
+
model = WhisperForConditionalGeneration.from_pretrained(
|
281 |
+
"openai/whisper-medium",
|
282 |
+
).to("cuda:0")
|
283 |
+
tokenizer = WhisperTokenizer.from_pretrained(
|
284 |
+
"openai/whisper-medium", language="English", task="transcribe"
|
285 |
+
)
|
286 |
+
|
287 |
+
from pathlib import Path
|
288 |
+
|
289 |
+
id = Path(fine_tuning_dir).stem
|
290 |
+
pdb.set_trace()
|
291 |
+
tokenizer.push_to_hub("KevinGeng/%s" % id)
|
292 |
+
# import pdb
|
293 |
+
feature_extractor = WhisperFeatureExtractor.from_pretrained(
|
294 |
+
"openai/whisper-medium"
|
295 |
+
)
|
296 |
+
|
297 |
+
def whisper_prepare_dataset(batch):
|
298 |
+
# load and resample audio data from 48 to 16kHz
|
299 |
+
audio = batch["audio"]
|
300 |
+
|
301 |
+
# compute log-Mel input features from input audio array
|
302 |
+
batch["input_features"] = feature_extractor(
|
303 |
+
audio["array"], sampling_rate=audio["sampling_rate"]
|
304 |
+
).input_features[0]
|
305 |
+
|
306 |
+
# encode target text to label ids
|
307 |
+
batch["labels"] = tokenizer(batch["transcription"]).input_ids
|
308 |
+
return batch
|
309 |
+
|
310 |
+
torch.cuda.empty_cache()
|
311 |
+
|
312 |
+
|
313 |
+
def my_map_to_pred(batch):
|
314 |
+
# pdb.set_trace()
|
315 |
+
audio = batch["audio"]
|
316 |
+
input_features = processor(
|
317 |
+
audio["array"],
|
318 |
+
sampling_rate=audio["sampling_rate"],
|
319 |
+
return_tensors="pt",
|
320 |
+
).input_features
|
321 |
+
# batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
|
322 |
+
batch["reference"] = processor.tokenizer._normalize(batch["transcription"])
|
323 |
+
|
324 |
+
with torch.no_grad():
|
325 |
+
# predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
|
326 |
+
predicted_ids = model.generate(input_features.to("cuda"))[0]
|
327 |
+
transcription = processor.decode(predicted_ids)  # decode via the processor/tokenizer; the generation model itself has no decode()
|
328 |
+
batch["prediction"] = model.tokenizer._normalize(transcription)
|
329 |
+
return batch
|
330 |
+
|
331 |
+
|
332 |
+
from dataclasses import dataclass
|
333 |
+
from typing import Any, Dict, List, Union
|
334 |
+
|
335 |
+
import torch
|
336 |
+
|
337 |
+
|
338 |
+
@dataclass
|
339 |
+
class DataCollatorSpeechSeq2SeqWithPadding:
|
340 |
+
processor: Any
|
341 |
+
|
342 |
+
def __call__(
|
343 |
+
self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
|
344 |
+
) -> Dict[str, torch.Tensor]:
|
345 |
+
# split inputs and labels since they have to be of different lengths and need different padding methods
|
346 |
+
# first treat the audio inputs by simply returning torch tensors
|
347 |
+
input_features = [
|
348 |
+
{"input_features": feature["input_features"]}
|
349 |
+
for feature in features
|
350 |
+
]
|
351 |
+
batch = self.processor.feature_extractor.pad(
|
352 |
+
input_features, return_tensors="pt"
|
353 |
+
)
|
354 |
+
|
355 |
+
# get the tokenized label sequences
|
356 |
+
label_features = [
|
357 |
+
{"input_ids": feature["labels"]} for feature in features
|
358 |
+
]
|
359 |
+
# pad the labels to max length
|
360 |
+
labels_batch = self.processor.tokenizer.pad(
|
361 |
+
label_features, return_tensors="pt"
|
362 |
+
)
|
363 |
+
|
364 |
+
# replace padding with -100 to ignore loss correctly
|
365 |
+
labels = labels_batch["input_ids"].masked_fill(
|
366 |
+
labels_batch.attention_mask.ne(1), -100
|
367 |
+
)
|
368 |
+
|
369 |
+
# if bos token is appended in previous tokenization step,
|
370 |
+
# cut bos token here as it's append later anyways
|
371 |
+
if (
|
372 |
+
(labels[:, 0] == self.processor.tokenizer.bos_token_id)
|
373 |
+
.all()
|
374 |
+
.cpu()
|
375 |
+
.item()
|
376 |
+
):
|
377 |
+
labels = labels[:, 1:]
|
378 |
+
|
379 |
+
batch["labels"] = labels
|
380 |
+
|
381 |
+
return batch
|
382 |
+
|
383 |
+
|
384 |
+
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
|
385 |
+
|
386 |
+
|
387 |
+
def compute_metrics(pred):
|
388 |
+
pred_ids = pred.predictions
|
389 |
+
label_ids = pred.label_ids
|
390 |
+
|
391 |
+
# replace -100 with the pad_token_id
|
392 |
+
label_ids[label_ids == -100] = tokenizer.pad_token_id
|
393 |
+
|
394 |
+
# we do not want to group tokens when computing the metrics
|
395 |
+
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
396 |
+
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
|
397 |
+
|
398 |
+
wer = 100 * WER.compute(predictions=pred_str, references=label_str)
|
399 |
+
|
400 |
+
return {"wer": wer}
|
401 |
+
|
402 |
+
encoded_train = combine_train.map(whisper_prepare_dataset, num_proc=4)
|
403 |
+
encoded_dev = conbine_dev.map(whisper_prepare_dataset, num_proc=4)
|
404 |
+
encoded_test = conbine_test.map(whisper_prepare_dataset, num_proc=4)
|
405 |
+
|
406 |
+
# extra_test
|
407 |
+
|
408 |
+
encoded_ori_test = test.map(whisper_prepare_dataset, num_proc=4)
|
409 |
+
encoded_Tony_test = Tony_test.map(whisper_prepare_dataset, num_proc=4)
|
410 |
+
encoded_John_test = John_test.map(whisper_prepare_dataset, num_proc=4)
|
411 |
+
encoded_Negel_test = Negel_test.map(whisper_prepare_dataset, num_proc=4)
|
412 |
+
|
413 |
+
encoded_Fary = Fary_PAL_test_dataset.map(whisper_prepare_dataset, num_proc=4)
|
414 |
+
encoded_healthy = healthy_test_dataset.map(whisper_prepare_dataset, num_proc=4)
|
415 |
+
|
416 |
+
torch.cuda.empty_cache()
|
417 |
+
|
418 |
+
training_args = Seq2SeqTrainingArguments(
|
419 |
+
output_dir=fine_tuning_dir, # change to a repo name of your choice
|
420 |
+
per_device_train_batch_size=8,
|
421 |
+
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
|
422 |
+
learning_rate=1e-5,
|
423 |
+
warmup_steps=50,
|
424 |
+
max_steps=1000,
|
425 |
+
gradient_checkpointing=True,
|
426 |
+
fp16=True,
|
427 |
+
evaluation_strategy="steps",
|
428 |
+
save_strategy="steps",
|
429 |
+
per_device_eval_batch_size=8,
|
430 |
+
predict_with_generate=True,
|
431 |
+
generation_max_length=512,
|
432 |
+
save_steps=20,
|
433 |
+
eval_steps=20,
|
434 |
+
logging_steps=10,
|
435 |
+
report_to=["tensorboard"],
|
436 |
+
load_best_model_at_end=True,
|
437 |
+
metric_for_best_model="wer",
|
438 |
+
greater_is_better=False,
|
439 |
+
save_total_limit=5,
|
440 |
+
push_to_hub=False,
|
441 |
+
)
|
442 |
+
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
443 |
+
|
444 |
+
trainer = Seq2SeqTrainer(
|
445 |
+
args=training_args,
|
446 |
+
model=model,
|
447 |
+
train_dataset=Negel_train.map(whisper_prepare_dataset, num_proc=4),  # Negel_train was not whisper-encoded above; encode so the collator finds input_features/labels
|
448 |
+
eval_dataset=Negel_dev.map(whisper_prepare_dataset, num_proc=4),
|
449 |
+
data_collator=data_collator,
|
450 |
+
compute_metrics=compute_metrics,
|
451 |
+
tokenizer=processor.feature_extractor,
|
452 |
+
callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
|
453 |
+
)
|
454 |
+
# callbacks=[EvalLoggingCallback()]
|
455 |
+
pdb.set_trace()
|
456 |
+
|
457 |
+
before_result_dict = {
|
458 |
+
"Ori_Test": trainer.evaluate(encoded_ori_test),
|
459 |
+
"Tony_Test": trainer.evaluate(encoded_Tony_test),
|
460 |
+
"John_Test": trainer.evaluate(encoded_John_test),
|
461 |
+
"Negel_Test": trainer.evaluate(encoded_Negel_test),
|
462 |
+
"Zeroshot_Fary_Test": trainer.evaluate(encoded_Fary),
|
463 |
+
"Healthy_Test": trainer.evaluate(encoded_healthy),
|
464 |
+
}
|
465 |
+
|
466 |
+
print(before_result_dict)
|
467 |
+
trainer.train()
|
468 |
+
|
469 |
+
pdb.set_trace()
|
470 |
+
result_dict = {
|
471 |
+
"Ori_Test": trainer.evaluate(encoded_ori_test),
|
472 |
+
"Tony_Test": trainer.evaluate(encoded_Tony_test),
|
473 |
+
"John_Test": trainer.evaluate(encoded_John_test),
|
474 |
+
"Negel_Test": trainer.evaluate(encoded_Negel_test),
|
475 |
+
"Zeroshot_Fary_Test": trainer.evaluate(encoded_Fary),
|
476 |
+
"Healthy_Test": trainer.evaluate(encoded_healthy),
|
477 |
+
}
|
478 |
+
|
479 |
+
pdb.set_trace()
|
480 |
+
# Evaluation
|
481 |
+
model.push_to_hub("KevinGeng/%s" % id)
|
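Note: the DataCollatorSpeechSeq2SeqWithPadding above replaces label padding with -100 so that the seq2seq cross-entropy loss ignores padded positions, and compute_metrics swaps the -100s back to the pad token id before decoding. A minimal, self-contained sketch of that masking step (the token ids and pad id below are invented for illustration; only the masked_fill pattern mirrors the collator):

import torch

# Toy padded label batch; 50257 stands in for the tokenizer's pad_token_id.
pad_token_id = 50257
input_ids = torch.tensor([[7, 8, 9, pad_token_id],
                          [7, 8, pad_token_id, pad_token_id]])
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# Positions where attention_mask != 1 are padding; set them to -100 so the
# loss computed during fine-tuning skips them.
labels = input_ids.masked_fill(attention_mask.ne(1), -100)
print(labels)
# tensor([[   7,    8,    9, -100],
#         [   7,    8, -100, -100]])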
local/new_whisper_fine_tuning_decode.py
ADDED
@@ -0,0 +1,509 @@
1 |
+
fine_tuning_dir = "/home/kevingeng/Disk2/laronix/laronix_automos/fine_tuned/SSD/model/Org_Tony_John_Negel_Train_557_Dev_79_Test_81/checkpoint-160"
|
2 |
+
"""
|
3 |
+
TODO:
|
4 |
+
+ [ ] Data load
|
5 |
+
+ [ ] Train / Test / Dev split
|
6 |
+
+ [ ] Train / Test Phase
|
7 |
+
+ [ ] Logging with Train / Dev / Test Loss
|
8 |
+
+ [ ] Evaluation metrics
|
9 |
+
"""
|
10 |
+
import pdb
|
11 |
+
import string
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
import evaluate
|
15 |
+
import librosa
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from datasets import Dataset, concatenate_datasets, load_dataset
|
19 |
+
from transformers import AutoProcessor
|
20 |
+
|
21 |
+
wer = evaluate.load("wer")
|
22 |
+
torch.cuda.set_device("cuda:0")
|
23 |
+
|
24 |
+
audio_dir = "./data/Patient_sil_trim_16k_normed_5_snr_40"
|
25 |
+
healthy_dir = "./data/Healthy"
|
26 |
+
Fary_PAL_30 = "./data/Fary_PAL_p326_20230110_30"
|
27 |
+
John_p326 = "./data/John_p326/output"
|
28 |
+
John_video = "./data/20230103_video"
|
29 |
+
|
30 |
+
## train
|
31 |
+
p326_300_dir = "./data/John_p326_large"
|
32 |
+
P1tony_arthur = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Arthur_the_Rat/PAL"
|
33 |
+
P1tony_rainbow = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Rainbow_Passage/Laronix"
|
34 |
+
|
35 |
+
P1tony = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/CONVERSATIONAL/PAL"
|
36 |
+
|
37 |
+
P4Negel = 'data/4_negal_152_clean_all'
|
38 |
+
|
39 |
+
def dataclean(example):
|
40 |
+
if example["audio"]["sampling_rate"] != 16000:
|
41 |
+
resampled_audio = librosa.resample(
|
42 |
+
y=example["audio"]["array"],
|
43 |
+
orig_sr=example["audio"]["sampling_rate"],
|
44 |
+
target_sr=16000,
|
45 |
+
)
|
46 |
+
|
47 |
+
return {
|
48 |
+
"audio": {
|
49 |
+
"path": example["audio"]["path"],
|
50 |
+
"array": resampled_audio,
|
51 |
+
"sampling_rate": 16000,
|
52 |
+
},
|
53 |
+
"transcription": example["transcription"]
|
54 |
+
.upper()
|
55 |
+
.translate(str.maketrans("", "", string.punctuation)),
|
56 |
+
}
|
57 |
+
else:
|
58 |
+
return {
|
59 |
+
"transcription": example["transcription"]
|
60 |
+
.upper()
|
61 |
+
.translate(str.maketrans("", "", string.punctuation))
|
62 |
+
}
|
63 |
+
|
64 |
+
P1tony_dataset = load_dataset("audiofolder", data_dir=P1tony, split="train")
|
65 |
+
P1tony_dataset = P1tony_dataset.map(dataclean)
|
66 |
+
|
67 |
+
P1tony_scripted1 = load_dataset(
|
68 |
+
"audiofolder", data_dir=P1tony_rainbow, split="train"
|
69 |
+
)
|
70 |
+
P1tony_scripted2 = load_dataset(
|
71 |
+
"audiofolder", data_dir=P1tony_arthur, split="train"
|
72 |
+
)
|
73 |
+
P1tony_scripted1 = P1tony_scripted1.map(dataclean)
|
74 |
+
P1tony_scripted2 = P1tony_scripted2.map(dataclean)
|
75 |
+
P1tony_scripted = concatenate_datasets([P1tony_scripted1, P1tony_scripted2])
|
76 |
+
|
77 |
+
class ChangeSampleRate(nn.Module):
|
78 |
+
def __init__(self, input_rate: int, output_rate: int):
|
79 |
+
super().__init__()
|
80 |
+
self.output_rate = output_rate
|
81 |
+
self.input_rate = input_rate
|
82 |
+
|
83 |
+
def forward(self, wav: torch.tensor) -> torch.tensor:
|
84 |
+
# Only accepts 1-channel waveform input
|
85 |
+
wav = wav.view(wav.size(0), -1)
|
86 |
+
new_length = wav.size(-1) * self.output_rate // self.input_rate
|
87 |
+
indices = torch.arange(new_length) * (
|
88 |
+
self.input_rate / self.output_rate
|
89 |
+
)
|
90 |
+
round_down = wav[:, indices.long()]
|
91 |
+
round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
|
92 |
+
output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
|
93 |
+
0
|
94 |
+
) + round_up * indices.fmod(1.0).unsqueeze(0)
|
95 |
+
return output
|
96 |
+
|
97 |
+
# resample and clean text data
|
98 |
+
def dataclean(example):
|
99 |
+
# pdb.set_trace()
|
100 |
+
if example["audio"]["sampling_rate"] != 16000:
|
101 |
+
resampled_audio = librosa.resample(
|
102 |
+
y=example["audio"]["array"],
|
103 |
+
orig_sr=example["audio"]["sampling_rate"],
|
104 |
+
target_sr=16000,
|
105 |
+
)
|
106 |
+
|
107 |
+
return {
|
108 |
+
"audio": {
|
109 |
+
"path": example["audio"]["path"],
|
110 |
+
"array": resampled_audio,
|
111 |
+
"sampling_rate": 16000,
|
112 |
+
},
|
113 |
+
"transcription": example["transcription"]
|
114 |
+
.upper()
|
115 |
+
.translate(str.maketrans("", "", string.punctuation)),
|
116 |
+
}
|
117 |
+
else:
|
118 |
+
return {
|
119 |
+
"transcription": example["transcription"]
|
120 |
+
.upper()
|
121 |
+
.translate(str.maketrans("", "", string.punctuation))
|
122 |
+
}
|
123 |
+
|
124 |
+
|
125 |
+
# processor = AutoFeatureExtractor.from_pretrained(
|
126 |
+
# "facebook/wav2vec2-base-960h"
|
127 |
+
# )
|
128 |
+
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
129 |
+
|
130 |
+
def prepare_dataset(batch):
|
131 |
+
audio = batch["audio"]
|
132 |
+
batch = processor(
|
133 |
+
audio["array"],
|
134 |
+
sampling_rate=audio["sampling_rate"],
|
135 |
+
text=batch["transcription"],
|
136 |
+
)
|
137 |
+
batch["input_length"] = len(batch["input_values"][0])
|
138 |
+
return batch
|
139 |
+
|
140 |
+
src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
|
141 |
+
src_dataset = src_dataset.map(dataclean)
|
142 |
+
p326_300_dataset = load_dataset(
|
143 |
+
"audiofolder", data_dir=p326_300_dir, split="train"
|
144 |
+
)
|
145 |
+
p326_300_dataset = p326_300_dataset.map(dataclean)
|
146 |
+
|
147 |
+
P4Negel_dataset = load_dataset("audiofolder", data_dir=P4Negel, split="train")
|
148 |
+
P4Negel_dataset = P4Negel_dataset.map(dataclean)
|
149 |
+
|
150 |
+
healthy_test_dataset = load_dataset(
|
151 |
+
"audiofolder", data_dir=healthy_dir, split="train"
|
152 |
+
)
|
153 |
+
healthy_test_dataset = healthy_test_dataset.map(dataclean)
|
154 |
+
|
155 |
+
Fary_PAL_test_dataset = load_dataset(
|
156 |
+
"audiofolder", data_dir=Fary_PAL_30, split="train"
|
157 |
+
)
|
158 |
+
Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)
|
159 |
+
|
160 |
+
John_p326_test_dataset = load_dataset(
|
161 |
+
"audiofolder", data_dir=John_p326, split="train"
|
162 |
+
)
|
163 |
+
John_p326_test_dataset = John_p326_test_dataset.map(dataclean)
|
164 |
+
|
165 |
+
John_video_test_dataset = load_dataset(
|
166 |
+
"audiofolder", data_dir=John_video, split="train"
|
167 |
+
)
|
168 |
+
John_video_test_dataset = John_video_test_dataset.map(dataclean)
|
169 |
+
|
170 |
+
|
171 |
+
def train_dev_test_split(
|
172 |
+
dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1
|
173 |
+
):
|
174 |
+
"""
|
175 |
+
input: dataset
|
176 |
+
dev_rate,
|
177 |
+
test_rate
|
178 |
+
seed
|
179 |
+
-------
|
180 |
+
Output:
|
181 |
+
dataset_dict{"train", "dev", "test"}
|
182 |
+
"""
|
183 |
+
train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
|
184 |
+
test = train_dev_test["test"]
|
185 |
+
train_dev = train_dev_test["train"]
|
186 |
+
|
187 |
+
# pdb.set_trace()
|
188 |
+
if len(train_dev) <= int(len(dataset) * dev_rate):
|
189 |
+
train = Dataset.from_dict({"audio": [], "transcription": []})
|
190 |
+
dev = train_dev
|
191 |
+
else:
|
192 |
+
train_dev = train_dev.train_test_split(
|
193 |
+
test_size=int(len(dataset) * dev_rate), seed=seed
|
194 |
+
)
|
195 |
+
train = train_dev["train"]
|
196 |
+
dev = train_dev["test"]
|
197 |
+
return train, dev, test
|
198 |
+
|
199 |
+
P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(
|
200 |
+
P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1
|
201 |
+
)
|
202 |
+
P1tony_train_ = concatenate_datasets([P1tony_train, P1tony_scripted])
|
203 |
+
|
204 |
+
# train_dev / test
|
205 |
+
ds = src_dataset.train_test_split(test_size=0.1, seed=1)
|
206 |
+
|
207 |
+
# dataset_libri = load_dataset(
|
208 |
+
# "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
|
209 |
+
# )
|
210 |
+
|
211 |
+
train_dev = ds["train"]
|
212 |
+
# train / dev
|
213 |
+
train_dev = train_dev.train_test_split(
|
214 |
+
test_size=int(len(src_dataset) * 0.1), seed=1
|
215 |
+
)
|
216 |
+
|
217 |
+
# Tony
|
218 |
+
Tony_train = P1tony_train_
|
219 |
+
Tony_dev = P1tony_dev
|
220 |
+
Tony_test = P1tony_test
|
221 |
+
|
222 |
+
# John
|
223 |
+
John_train, John_dev, John_test = train_dev_test_split(p326_300_dataset, dev_rate=0.1, test_rate=0.1)
|
224 |
+
# Negel
|
225 |
+
Negel_train, Negel_dev, Negel_test = train_dev_test_split(P4Negel_dataset, dev_rate=0.1, test_rate=0.1)
|
226 |
+
|
227 |
+
# train/dev/test
|
228 |
+
train = train_dev["train"]
|
229 |
+
test = ds["test"]
|
230 |
+
dev = train_dev["test"]
|
231 |
+
|
232 |
+
# combined
|
233 |
+
combine_train = concatenate_datasets([train, Tony_train, John_train, Negel_train])
|
234 |
+
conbine_dev = concatenate_datasets([dev, Tony_dev, John_dev, Negel_dev])
|
235 |
+
conbine_test = concatenate_datasets([test, Tony_test, John_test, Negel_test])
|
236 |
+
|
237 |
+
# encoded_train = combine_train.map(prepare_dataset, num_proc=4)
|
238 |
+
# encoded_dev = conbine_dev.map(prepare_dataset, num_proc=4)
|
239 |
+
# encoded_test = conbine_test.map(prepare_dataset, num_proc=4)
|
240 |
+
|
241 |
+
# # extra_test
|
242 |
+
# encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
|
243 |
+
# encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
|
244 |
+
|
245 |
+
# encoded_ori_test = test.map(prepare_dataset, num_proc=4)
|
246 |
+
# encoded_Tony_test = Tony_test.map(prepare_dataset, num_proc=4)
|
247 |
+
# encoded_John_test = John_test.map(prepare_dataset, num_proc=4)
|
248 |
+
# encoded_Negel_test = Negel_test.map(prepare_dataset, num_proc=4)
|
249 |
+
|
250 |
+
# encoded_train = train.map(prepare_dataset, num_proc=4)
|
251 |
+
# encoded_dev = dev.map(prepare_dataset, num_proc=4)
|
252 |
+
# p326_encoded_train = p326_300_dataset.map(prepare_dataset, num_proc=4)
|
253 |
+
|
254 |
+
# combine large p326 in to training set
|
255 |
+
# encoded_train = concatenate_datasets([encoded_train, p326_encoded_train])
|
256 |
+
|
257 |
+
# encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
|
258 |
+
# encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
|
259 |
+
|
260 |
+
# pdb.set_trace()
|
261 |
+
import numpy as np
|
262 |
+
|
263 |
+
WER = evaluate.load("wer")
|
264 |
+
|
265 |
+
## Whisper decoding
|
266 |
+
|
267 |
+
from transformers import (Seq2SeqTrainer, Seq2SeqTrainingArguments,
|
268 |
+
WhisperFeatureExtractor,
|
269 |
+
WhisperForConditionalGeneration, WhisperModel,
|
270 |
+
WhisperProcessor, WhisperTokenizer)
|
271 |
+
|
272 |
+
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
|
273 |
+
model = WhisperForConditionalGeneration.from_pretrained(
|
274 |
+
fine_tuning_dir,
|
275 |
+
).to("cuda:0")
|
276 |
+
# model = WhisperForConditionalGeneration.from_pretrained(
|
277 |
+
# "openai/whisper-medium",
|
278 |
+
# ).to("cuda:0")
|
279 |
+
tokenizer = WhisperTokenizer.from_pretrained(
|
280 |
+
"openai/whisper-medium", language="English", task="transcribe"
|
281 |
+
)
|
282 |
+
|
283 |
+
from pathlib import Path
|
284 |
+
|
285 |
+
id = Path(fine_tuning_dir).stem
|
286 |
+
pdb.set_trace()
|
287 |
+
# tokenizer.push_to_hub("KevinGeng/%s" % id)
|
288 |
+
# import pdb
|
289 |
+
feature_extractor = WhisperFeatureExtractor.from_pretrained(
|
290 |
+
"openai/whisper-medium"
|
291 |
+
)
|
292 |
+
|
293 |
+
def whisper_prepare_dataset(batch):
|
294 |
+
# load and resample audio data from 48 to 16kHz
|
295 |
+
audio = batch["audio"]
|
296 |
+
|
297 |
+
# compute log-Mel input features from input audio array
|
298 |
+
batch["input_features"] = feature_extractor(
|
299 |
+
audio["array"], sampling_rate=audio["sampling_rate"]
|
300 |
+
).input_features[0]
|
301 |
+
|
302 |
+
# encode target text to label ids
|
303 |
+
batch["labels"] = tokenizer(batch["transcription"]).input_ids
|
304 |
+
return batch
|
305 |
+
|
306 |
+
torch.cuda.empty_cache()
|
307 |
+
|
308 |
+
|
309 |
+
def my_map_to_pred(batch):
|
310 |
+
# pdb.set_trace()
|
311 |
+
audio = batch["audio"]
|
312 |
+
input_features = processor(
|
313 |
+
audio["array"],
|
314 |
+
sampling_rate=audio["sampling_rate"],
|
315 |
+
return_tensors="pt",
|
316 |
+
).input_features
|
317 |
+
# batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
|
318 |
+
batch["reference"] = processor.tokenizer._normalize(batch["transcription"])
|
319 |
+
|
320 |
+
with torch.no_grad():
|
321 |
+
# predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
|
322 |
+
predicted_ids = model.generate(input_features.to("cuda"))[0]
|
323 |
+
transcription = processor.decode(predicted_ids)  # decode via the processor/tokenizer; the generation model itself has no decode()
|
324 |
+
batch["prediction"] = model.tokenizer._normalize(transcription)
|
325 |
+
return batch
|
326 |
+
|
327 |
+
|
328 |
+
from dataclasses import dataclass
|
329 |
+
from typing import Any, Dict, List, Union
|
330 |
+
|
331 |
+
import torch
|
332 |
+
|
333 |
+
|
334 |
+
@dataclass
|
335 |
+
class DataCollatorSpeechSeq2SeqWithPadding:
|
336 |
+
processor: Any
|
337 |
+
|
338 |
+
def __call__(
|
339 |
+
self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
|
340 |
+
) -> Dict[str, torch.Tensor]:
|
341 |
+
# split inputs and labels since they have to be of different lengths and need different padding methods
|
342 |
+
# first treat the audio inputs by simply returning torch tensors
|
343 |
+
input_features = [
|
344 |
+
{"input_features": feature["input_features"]}
|
345 |
+
for feature in features
|
346 |
+
]
|
347 |
+
batch = self.processor.feature_extractor.pad(
|
348 |
+
input_features, return_tensors="pt"
|
349 |
+
)
|
350 |
+
|
351 |
+
# get the tokenized label sequences
|
352 |
+
label_features = [
|
353 |
+
{"input_ids": feature["labels"]} for feature in features
|
354 |
+
]
|
355 |
+
# pad the labels to max length
|
356 |
+
labels_batch = self.processor.tokenizer.pad(
|
357 |
+
label_features, return_tensors="pt"
|
358 |
+
)
|
359 |
+
|
360 |
+
# replace padding with -100 to ignore loss correctly
|
361 |
+
labels = labels_batch["input_ids"].masked_fill(
|
362 |
+
labels_batch.attention_mask.ne(1), -100
|
363 |
+
)
|
364 |
+
|
365 |
+
# if bos token is appended in previous tokenization step,
|
366 |
+
# cut bos token here as it's append later anyways
|
367 |
+
if (
|
368 |
+
(labels[:, 0] == self.processor.tokenizer.bos_token_id)
|
369 |
+
.all()
|
370 |
+
.cpu()
|
371 |
+
.item()
|
372 |
+
):
|
373 |
+
labels = labels[:, 1:]
|
374 |
+
|
375 |
+
batch["labels"] = labels
|
376 |
+
|
377 |
+
return batch
|
378 |
+
|
379 |
+
|
380 |
+
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
|
381 |
+
|
382 |
+
|
383 |
+
def compute_metrics(pred):
|
384 |
+
pred_ids = pred.predictions
|
385 |
+
label_ids = pred.label_ids
|
386 |
+
|
387 |
+
# replace -100 with the pad_token_id
|
388 |
+
label_ids[label_ids == -100] = tokenizer.pad_token_id
|
389 |
+
|
390 |
+
# we do not want to group tokens when computing the metrics
|
391 |
+
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
392 |
+
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
|
393 |
+
|
394 |
+
wer = 100 * WER.compute(predictions=pred_str, references=label_str)
|
395 |
+
|
396 |
+
return {"wer": wer}
|
397 |
+
|
398 |
+
encoded_train = combine_train.map(whisper_prepare_dataset, num_proc=4)
|
399 |
+
encoded_dev = conbine_dev.map(whisper_prepare_dataset, num_proc=4)
|
400 |
+
encoded_test = conbine_test.map(whisper_prepare_dataset, num_proc=4)
|
401 |
+
|
402 |
+
# extra_test
|
403 |
+
|
404 |
+
encoded_ori_test = test.map(whisper_prepare_dataset, num_proc=4) # 7 / 16
|
405 |
+
|
406 |
+
encoded_Tony_test = Tony_test.map(whisper_prepare_dataset, num_proc=4) # 0 / 19
|
407 |
+
|
408 |
+
encoded_John_test = John_test.map(whisper_prepare_dataset, num_proc=4) # 0 / 30
|
409 |
+
# [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]
|
410 |
+
encoded_Negel_test = Negel_test.map(whisper_prepare_dataset, num_proc=4) # 12 / 16
|
411 |
+
# [False, True, True, True, True, True, True, False, True, True, True, True, True, True, False, False]
|
412 |
+
encoded_Fary = Fary_PAL_test_dataset.map(whisper_prepare_dataset, num_proc=4) # 12 / 30
|
413 |
+
# [True, True, True, True, True, False, False, False, True, False, False, True, False, False, True, False, True, True, False, True, True, False, False, False, False, False, False, False, False, False]
|
414 |
+
encoded_healthy = healthy_test_dataset.map(whisper_prepare_dataset, num_proc=4) # 5 / 160
|
415 |
+
# [False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, True, True, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]
|
416 |
+
|
417 |
+
# Make sure of content variability: drop test utterances whose transcription also appears in train/dev
|
418 |
+
train_tuple = tuple(encoded_train['transcription'])
|
419 |
+
dev_tuple = tuple(encoded_dev['transcription'])
|
420 |
+
train_dev_tuple = (train_tuple + dev_tuple)
|
421 |
+
|
422 |
+
pdb.set_trace()
|
423 |
+
new_encoded_test = encoded_test.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_test['transcription']]))[0])
|
424 |
+
new_encoded_ori_test = encoded_ori_test.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_ori_test['transcription']]))[0])
|
425 |
+
new_encoded_Tony_test = encoded_Tony_test.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_Tony_test['transcription']]))[0])
|
426 |
+
new_encoded_John_test = encoded_John_test.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_John_test['transcription']]))[0])
|
427 |
+
new_encoded_Negel_test = encoded_Negel_test.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_Negel_test['transcription']]))[0])
|
428 |
+
new_encoded_Fary = encoded_Fary.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_Fary['transcription']]))[0])
|
429 |
+
new_encoded_healthy = encoded_healthy.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_healthy['transcription']]))[0])
|
430 |
+
pdb.set_trace()
|
431 |
+
torch.cuda.empty_cache()
|
432 |
+
|
433 |
+
training_args = Seq2SeqTrainingArguments(
|
434 |
+
output_dir=fine_tuning_dir, # change to a repo name of your choice
|
435 |
+
per_device_train_batch_size=8,
|
436 |
+
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
|
437 |
+
learning_rate=1e-5,
|
438 |
+
warmup_steps=50,
|
439 |
+
max_steps=1000,
|
440 |
+
gradient_checkpointing=True,
|
441 |
+
fp16=True,
|
442 |
+
evaluation_strategy="steps",
|
443 |
+
save_strategy="steps",
|
444 |
+
per_device_eval_batch_size=8,
|
445 |
+
predict_with_generate=True,
|
446 |
+
generation_max_length=512,
|
447 |
+
save_steps=20,
|
448 |
+
eval_steps=20,
|
449 |
+
logging_steps=10,
|
450 |
+
report_to=["tensorboard"],
|
451 |
+
load_best_model_at_end=True,
|
452 |
+
metric_for_best_model="wer",
|
453 |
+
greater_is_better=False,
|
454 |
+
save_total_limit=10,
|
455 |
+
push_to_hub=False,
|
456 |
+
)
|
457 |
+
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
458 |
+
|
459 |
+
trainer = Seq2SeqTrainer(
|
460 |
+
args=training_args,
|
461 |
+
model=model,
|
462 |
+
train_dataset=encoded_train,
|
463 |
+
eval_dataset=encoded_dev,
|
464 |
+
data_collator=data_collator,
|
465 |
+
compute_metrics=compute_metrics,
|
466 |
+
tokenizer=processor.feature_extractor,
|
467 |
+
callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
|
468 |
+
)
|
469 |
+
# callbacks=[EvalLoggingCallback()]
|
470 |
+
pdb.set_trace()
|
471 |
+
|
472 |
+
result_dict = {
|
473 |
+
"Ori_Test": trainer.evaluate(encoded_ori_test),
|
474 |
+
"Tony_Test": trainer.evaluate(encoded_Tony_test),
|
475 |
+
"John_Test": trainer.evaluate(encoded_John_test),
|
476 |
+
"Negel_Test": trainer.evaluate(encoded_Negel_test),
|
477 |
+
"Zeroshot_Fary_Test": trainer.evaluate(encoded_Fary),
|
478 |
+
"Healthy_Test": trainer.evaluate(encoded_healthy),
|
479 |
+
}
|
480 |
+
|
481 |
+
# print(result_dict)
|
482 |
+
|
483 |
+
pdb.set_trace()
|
484 |
+
trainer.evaluate(encoded_test)
|
485 |
+
trainer.evaluate(new_encoded_test)
|
486 |
+
new_result_dict = {
|
487 |
+
"Ori_Test": trainer.evaluate(new_encoded_ori_test), # 'eval_wer': 12.345679012345679,
|
488 |
+
"Tony_Test": trainer.evaluate(new_encoded_Tony_test), # 'eval_wer': 25.0,
|
489 |
+
"John_Test": trainer.evaluate(new_encoded_John_test),
|
490 |
+
"Negel_Test": trainer.evaluate(new_encoded_Negel_test), # 2.08
|
491 |
+
"Zeroshot_Fary_Test": trainer.evaluate(new_encoded_Fary), ## 11.49
|
492 |
+
"Healthy_Test": trainer.evaluate(new_encoded_healthy),
|
493 |
+
}
|
494 |
+
|
495 |
+
print(new_result_dict)
|
496 |
+
|
497 |
+
# pdb.set_trace()
|
498 |
+
# result_dict = {
|
499 |
+
# "Ori_Test": trainer.evaluate(encoded_ori_test),
|
500 |
+
# "Tony_Test": trainer.evaluate(encoded_Tony_test),
|
501 |
+
# "John_Test": trainer.evaluate(encoded_John_test),
|
502 |
+
# "Negel_Test": trainer.evaluate(encoded_Negel_test),
|
503 |
+
# "Zeroshot_Fary_Test": trainer.evaluate(encoded_Fary),
|
504 |
+
# "Healthy_Test": trainer.evaluate(encoded_healthy),
|
505 |
+
# }
|
506 |
+
|
507 |
+
# pdb.set_trace()
|
508 |
+
# # Evaluation
|
509 |
+
# model.push_to_hub("KevinGeng/%s" % id)
|
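Note: the block above that builds train_dev_tuple and then selects indices is a content-variability check: it drops test utterances whose transcription also occurs in the train/dev text, so the reported WER is not inflated by sentences the model may have memorised. A minimal sketch of the same idea using datasets.Dataset.filter on a toy in-memory dataset (names and strings are illustrative only, not the project's data):

from datasets import Dataset

# Toy stand-ins for the encoded test split and the train/dev transcriptions.
test_ds = Dataset.from_dict(
    {"transcription": ["THE RAINBOW PASSAGE", "ARTHUR THE RAT", "A NEW SENTENCE"]}
)
train_dev_texts = {"THE RAINBOW PASSAGE", "ARTHUR THE RAT"}

# Keep only utterances whose text never appeared in train/dev.
unseen_test = test_ds.filter(lambda ex: ex["transcription"] not in train_dev_texts)
print(unseen_test["transcription"])  # ['A NEW SENTENCE']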
local/post_processing.py
ADDED
@@ -0,0 +1,56 @@
1 |
+
'''
|
2 |
+
# Post processing module for data recording
|
3 |
+
# Author: Kevin Geng @Laronix, Sep. 2022
|
4 |
+
|
5 |
+
# Load log.csv, generate standard wav files at the selected sample rate, and calculate statistical features
|
6 |
+
'''
|
7 |
+
|
8 |
+
from random import sample
|
9 |
+
import librosa
|
10 |
+
import soundfile as sf
|
11 |
+
import numpy as np
|
12 |
+
import pdb
|
13 |
+
from pathlib import Path
|
14 |
+
import sys
|
15 |
+
import pandas as pd
|
16 |
+
indir = Path(sys.argv[1])
|
17 |
+
assert indir.exists() == True
|
18 |
+
wavs = Path(indir/Path("Audio_to_Evaluate")).glob("**/*.wav")
|
19 |
+
log = Path(indir/Path("log.csv"))
|
20 |
+
|
21 |
+
# x = np.loadtxt(log, dtype=str, delimiter=",")
|
22 |
+
x = pd.read_csv(log, header=0)
|
23 |
+
|
24 |
+
# y, sr = librosa.load("/home/kevingeng/laronix_automos/Julianna/Audio_to_evaluate/tmp0kgcdpi2.wav", sr=48000)
|
25 |
+
outdir = indir/Path("output")
|
26 |
+
# pdb.set_trace()
|
27 |
+
# outdir_clean = indir/Path("output_clean")
|
28 |
+
Path.mkdir(outdir, exist_ok=True)
|
29 |
+
# Path.mkdir(outdir_clean, exist_ok=True)
|
30 |
+
## Capitalize "Evaluate"
|
31 |
+
# for i, j in zip(x["Audio_to_Evaluate"], x["Reference_ID"]):
|
32 |
+
# y, sr = librosa.load(i, sr=48000)
|
33 |
+
# # kevin 1017 John's trial with original data.
|
34 |
+
# y_ = librosa.util.normalize(y, norm=5)
|
35 |
+
# y_cut, index = librosa.effects.trim(y_, top_db=30)
|
36 |
+
# # normalized and cut
|
37 |
+
# # pdb.set_trace()
|
38 |
+
# # sf.write(outdir/Path(str(indir)+"_"+ j +".wav"), y_cut, samplerate=sr)
|
39 |
+
# sf.write(outdir/Path(Path(indir).stem+"_"+ j +".wav"), y_cut, samplerate=sr)
|
40 |
+
|
41 |
+
def process_audio(file_path, ref_id, sr=48000, norm=5, top_db=30):
|
42 |
+
y, _ = librosa.load(file_path, sr=sr)
|
43 |
+
y_norm = librosa.util.normalize(y, norm=norm)
|
44 |
+
y_cut, _ = librosa.effects.trim(y_norm, top_db=top_db)
|
45 |
+
return y_cut
|
46 |
+
|
47 |
+
def save_audio(y_cut, ref_id, outdir, indir, sr=48000):
|
48 |
+
out_path = outdir / f"{Path(indir).stem}_{ref_id}.wav"
|
49 |
+
sf.write(out_path, y_cut, samplerate=sr)
|
50 |
+
|
51 |
+
def main(audio_files, ref_ids, outdir, indir):
|
52 |
+
for file_path, ref_id in zip(audio_files, ref_ids):
|
53 |
+
y_cut = process_audio(file_path, ref_id)
|
54 |
+
save_audio(y_cut, ref_id, outdir, indir)
|
55 |
+
|
56 |
+
main(x["Audio_to_Evaluate"], x["Reference_ID"], outdir, indir)
|
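Note: the script expects a recording directory (passed as sys.argv[1]) containing a log.csv with Audio_to_Evaluate and Reference_ID columns, and writes normalised, silence-trimmed 48 kHz wavs into <indir>/output. A minimal sketch of the same normalise-and-trim step on a single file (the file name below is an assumption for illustration; the full script itself would be run as python local/post_processing.py <recording_dir>):

import librosa
import soundfile as sf

# Any mono recording works here; the file name is invented for this sketch.
y, sr = librosa.load("example_take.wav", sr=48000)

# Same steps as process_audio(): 5-norm amplitude normalisation, then trim
# leading/trailing silence below a 30 dB threshold relative to the peak.
y_norm = librosa.util.normalize(y, norm=5)
y_cut, _ = librosa.effects.trim(y_norm, top_db=30)
sf.write("example_take_trimmed.wav", y_cut, samplerate=sr)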
local/wer_plot_report.py
ADDED
@@ -0,0 +1,45 @@
1 |
+
from pathlib import Path
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import sys
|
8 |
+
import pdb
|
9 |
+
|
10 |
+
threshold = 0.3
|
11 |
+
if __name__ == "__main__":
|
12 |
+
wer_csv = sys.argv[1]
|
13 |
+
df = pd.read_csv(wer_csv)
|
14 |
+
fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(25, 15))
|
15 |
+
|
16 |
+
# Hist for distribution
|
17 |
+
ax[0].set_xlabel("Word Error Rate")
|
18 |
+
ax[0].set_ylabel("Counts")
|
19 |
+
ax[0].set_xlim(left=0.0, right=df['wer'].max())
|
20 |
+
ax[0].hist(df['wer'], bins=50)
|
21 |
+
ax[0].axvline(x=threshold, color="r")
|
22 |
+
# plt.savefig("hist.png")
|
23 |
+
|
24 |
+
# Line curve for each sentences
|
25 |
+
colors = ['green' if x < threshold else 'red' for x in df['wer']]
|
26 |
+
|
27 |
+
new_ids = [str(x).split('.')[0] for x in df['id']]
|
28 |
+
ax[1].set_xlabel("IDs")
|
29 |
+
ax[1].set_ylabel("Word Error Rate")
|
30 |
+
ax[1].scatter(new_ids, df['wer'], c=colors, marker='o')
|
31 |
+
ax[1].vlines(new_ids, ymin=0, ymax=df['wer'], colors='grey', linestyle='dotted', label='Vertical Lines')
|
32 |
+
ax[1].axhline(y=threshold, xmin=0, xmax=len(new_ids), color='r')
|
33 |
+
|
34 |
+
# ax[0].axhline(y=threshold, color="black")
|
35 |
+
|
36 |
+
# for i, v in enumerate(df['wer']):
|
37 |
+
# plt.text(str(df['id'][i]).split('.')[0], -2, str(df['id'][i]), ha='center', fontsize=3)
|
38 |
+
|
39 |
+
ax[1].set_xticklabels(new_ids, rotation=90, fontsize=10)
|
40 |
+
ax[1].tick_params(axis='x', width=20)
|
41 |
+
# ax[1].set_xlim(10, len(df['id']) + 10)
|
42 |
+
plt.tight_layout()
|
43 |
+
pdb.set_trace()
|
44 |
+
# fig.savefig("%s/%s.png"%(Path(sys.argv[1]).parent, sys.argv[1].split('/')[-1]), format='png')
|
45 |
+
fig.savefig("%s.png"%(sys.argv[1]), format='png')
|
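Note: the plotting script assumes a per-utterance WER report in CSV form with id and wer columns, and marks utterances against the 0.3 threshold. A small self-contained sketch of that input format and the threshold check (the rows are invented for illustration):

import io
import pandas as pd

# Toy WER report with the two columns the script reads.
csv_text = "id,wer\nrec_001.wav,0.12\nrec_002.wav,0.45\nrec_003.wav,0.08\n"
df = pd.read_csv(io.StringIO(csv_text))

threshold = 0.3  # same cut-off as the red line drawn by the script
n_bad = (df["wer"] >= threshold).sum()
print(f"{n_bad} of {len(df)} utterances at or above the threshold")  # 1 of 3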
local/whisper_fine_tuning_large_with_negel.py
ADDED
@@ -0,0 +1,442 @@
1 |
+
fine_tuning_dir = "/home/kevingeng/Disk2/laronix/laronix_automos/fine_tuned/SSD/model/Michael_52_with_Large_AVA_script_conv_train_conv_dev/checkpoint-60"
|
2 |
+
"""
|
3 |
+
TODO:
|
4 |
+
+ [x] Load Configuration
|
5 |
+
+ [ ] Multi ASR Engine
|
6 |
+
+ [ ] Batch / Real Time support
|
7 |
+
"""
|
8 |
+
from pathlib import Path
|
9 |
+
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
|
10 |
+
from datasets import load_dataset, concatenate_datasets
|
11 |
+
from datasets import Dataset, Audio
|
12 |
+
import pdb
|
13 |
+
import string
|
14 |
+
import librosa
|
15 |
+
# local import
|
16 |
+
import sys
|
17 |
+
|
18 |
+
sys.path.append("src")
|
19 |
+
import torch
|
20 |
+
torch.cuda.set_device("cuda:0")
|
21 |
+
# token_model = AutoModelForCTC.from_pretrained(
|
22 |
+
# "facebook/wav2vec2-base-960h"
|
23 |
+
# )
|
24 |
+
|
25 |
+
# audio_dir= "/Users/kevingeng/Laronix/laronix_automos/data/Patient_sil_trim_16k_normed_5_snr_40/"
|
26 |
+
audio_dir ="./data/Patient_sil_trim_16k_normed_5_snr_40"
|
27 |
+
healthy_dir="./data/Healthy"
|
28 |
+
Fary_PAL_30="./data/Fary_PAL_p326_20230110_30"
|
29 |
+
John_p326 = "./data/John_p326/output"
|
30 |
+
John_video = "./data/20230103_video"
|
31 |
+
p326_300_dir ="./data/John_p326_large"
|
32 |
+
negel_152 = "./data/4_negal_152_clean_all"
|
33 |
+
|
34 |
+
michael3_52 = "data/3_michael_20230619_52"
|
35 |
+
|
36 |
+
patient_T = "data/Patient_T/Patient_T"
|
37 |
+
patient_L = "data/Patient_L/Patient_L"
|
38 |
+
P1tony = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/CONVERSATIONAL/PAL"
|
39 |
+
P1tony_arthur = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Arthur_the_Rat/PAL"
|
40 |
+
P1tony_rainbow = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Rainbow_Passage/Laronix"
|
41 |
+
|
42 |
+
def dataclean(example):
|
43 |
+
# pdb.set_trace()
|
44 |
+
if example['audio']['sampling_rate'] != 16000:
|
45 |
+
resampled_audio = librosa.resample(y=example['audio']['array'],
|
46 |
+
orig_sr= example['audio']['sampling_rate'],
|
47 |
+
target_sr=16000)
|
48 |
+
# torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
|
49 |
+
# resampled_audio = resampler(example['audio']['array'])
|
50 |
+
|
51 |
+
return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
|
52 |
+
"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
53 |
+
else:
|
54 |
+
return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
55 |
+
|
56 |
+
# patient_L_test_dataset = load_dataset("audiofolder", data_dir=patient_L, split="train")
|
57 |
+
# patient_L_test_dataset = patient_L_test_dataset.map(dataclean)
|
58 |
+
|
59 |
+
# patient_T_test_dataset = load_dataset("audiofolder", data_dir=patient_T, split="train")
|
60 |
+
# patient_T_test_dataset = patient_T_test_dataset.map(dataclean)
|
61 |
+
|
62 |
+
P1tony_dataset = load_dataset("audiofolder", data_dir=P1tony, split="train")
|
63 |
+
P1tony_dataset = P1tony_dataset.map(dataclean)
|
64 |
+
|
65 |
+
P3Micheal_dataset_52 = load_dataset("audiofolder", data_dir=michael3_52, split="train")
|
66 |
+
P3Micheal_dataset_52 = P3Micheal_dataset_52.map(dataclean)
|
67 |
+
|
68 |
+
# negel_152_dataset = load_dataset("audiofolder", data_dir=negel_152, split="train")
|
69 |
+
# negel_152_dataset = negel_152_dataset.map(dataclean)
|
70 |
+
|
71 |
+
|
72 |
+
# pdb.set_trace()
|
73 |
+
# P1tony_scripted1 = load_dataset("audiofolder", data_dir=P1tony_rainbow, split="train")
|
74 |
+
# P1tony_scripted2 = load_dataset("audiofolder", data_dir=P1tony_arthur, split="train")
|
75 |
+
# P1tony_scripted1 = P1tony_scripted1.map(dataclean)
|
76 |
+
# P1tony_scripted2 = P1tony_scripted2.map(dataclean)
|
77 |
+
# P1tony_scripted = concatenate_datasets([P1tony_scripted1, P1tony_scripted2])
|
78 |
+
|
79 |
+
# audio_dir ="/home/kevingeng/laronix/laronix_automos/data/Healthy"
|
80 |
+
# tgt_audio_dir= "/Users/kevingeng/Laronix/Dataset/Pneumatic/automos"
|
81 |
+
|
82 |
+
# Get Transcription, WER and PPM
|
83 |
+
"""
|
84 |
+
TODO:
|
85 |
+
[DONE]: Automatically generate config
|
86 |
+
"""
|
87 |
+
|
88 |
+
import yaml
|
89 |
+
import argparse
|
90 |
+
import sys
|
91 |
+
from pathlib import Path
|
92 |
+
|
93 |
+
sys.path.append("./src")
|
94 |
+
import lightning_module
|
95 |
+
from UV import plot_UV, get_speech_interval
|
96 |
+
from transformers import pipeline
|
97 |
+
from rich.progress import track
|
98 |
+
from rich import print as rprint
|
99 |
+
import numpy as np
|
100 |
+
import jiwer
|
101 |
+
import pdb
|
102 |
+
import torch.nn as nn
|
103 |
+
import torch
|
104 |
+
import torchaudio
|
105 |
+
import gradio as gr
|
106 |
+
from sys import flags
|
107 |
+
from random import sample
|
108 |
+
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
109 |
+
|
110 |
+
import evaluate
|
111 |
+
|
112 |
+
wer = evaluate.load("wer")
|
113 |
+
|
114 |
+
# root_path = Path(__file__).parents[1]
|
115 |
+
|
116 |
+
class ChangeSampleRate(nn.Module):
|
117 |
+
def __init__(self, input_rate: int, output_rate: int):
|
118 |
+
super().__init__()
|
119 |
+
self.output_rate = output_rate
|
120 |
+
self.input_rate = input_rate
|
121 |
+
|
122 |
+
def forward(self, wav: torch.tensor) -> torch.tensor:
|
123 |
+
# Only accepts 1-channel waveform input
|
124 |
+
wav = wav.view(wav.size(0), -1)
|
125 |
+
new_length = wav.size(-1) * self.output_rate // self.input_rate
|
126 |
+
indices = torch.arange(new_length) * (
|
127 |
+
self.input_rate / self.output_rate
|
128 |
+
)
|
129 |
+
round_down = wav[:, indices.long()]
|
130 |
+
round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
|
131 |
+
output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
|
132 |
+
0
|
133 |
+
) + round_up * indices.fmod(1.0).unsqueeze(0)
|
134 |
+
return output
|
135 |
+
|
136 |
+
# resample and clean text data
|
137 |
+
def dataclean(example):
|
138 |
+
# pdb.set_trace()
|
139 |
+
if example['audio']['sampling_rate'] != 16000:
|
140 |
+
resampled_audio = librosa.resample(y=example['audio']['array'],
|
141 |
+
orig_sr= example['audio']['sampling_rate'],
|
142 |
+
target_sr=16000)
|
143 |
+
# torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
|
144 |
+
# resampled_audio = resampler(example['audio']['array'])
|
145 |
+
|
146 |
+
return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
|
147 |
+
"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
148 |
+
else:
|
149 |
+
return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
150 |
+
|
151 |
+
# processor = AutoFeatureExtractor.from_pretrained(
|
152 |
+
# "facebook/wav2vec2-base-960h"
|
153 |
+
# )
|
154 |
+
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
155 |
+
|
156 |
+
def prepare_dataset(batch):
|
157 |
+
audio = batch["audio"]
|
158 |
+
batch = processor(audio["array"], sampling_rate = audio["sampling_rate"], text=batch['transcription'])
|
159 |
+
batch["input_length"] = len(batch["input_values"][0])
|
160 |
+
return batch
|
161 |
+
|
162 |
+
src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
|
163 |
+
src_dataset = src_dataset.map(dataclean)
|
164 |
+
p326_300_dataset = load_dataset("audiofolder", data_dir=p326_300_dir, split="train")
|
165 |
+
p326_300_dataset = p326_300_dataset.map(dataclean)
|
166 |
+
|
167 |
+
# healthy_test_dataset = load_dataset("audiofolder", data_dir=healthy_dir, split='train')
|
168 |
+
# healthy_test_dataset = healthy_test_dataset.map(dataclean)
|
169 |
+
|
170 |
+
# Fary_PAL_test_dataset = load_dataset("audiofolder", data_dir=Fary_PAL_30, split='train')
|
171 |
+
# Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)
|
172 |
+
|
173 |
+
# John_p326_test_dataset = load_dataset("audiofolder", data_dir=John_p326, split='train')
|
174 |
+
# John_p326_test_dataset = John_p326_test_dataset.map(dataclean)
|
175 |
+
|
176 |
+
# John_video_test_dataset = load_dataset("audiofolder", data_dir=John_video, split='train')
|
177 |
+
# John_video_test_dataset = John_video_test_dataset.map(dataclean)
|
178 |
+
|
179 |
+
# pdb.set_trace()
|
180 |
+
|
181 |
+
def train_dev_test_split(dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1):
|
182 |
+
"""
|
183 |
+
input: dataset
|
184 |
+
dev_rate,
|
185 |
+
test_rate
|
186 |
+
seed
|
187 |
+
-------
|
188 |
+
Output:
|
189 |
+
dataset_dict{"train", "dev", "test"}
|
190 |
+
"""
|
191 |
+
train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
|
192 |
+
test = train_dev_test["test"]
|
193 |
+
train_dev = train_dev_test['train']
|
194 |
+
|
195 |
+
# pdb.set_trace()
|
196 |
+
if len(train_dev) <= int(len(dataset)*dev_rate):
|
197 |
+
train = Dataset.from_dict({"audio": [], "transcription": []})
|
198 |
+
dev = train_dev
|
199 |
+
else:
|
200 |
+
train_dev = train_dev.train_test_split(test_size=int(len(dataset)*dev_rate), seed=seed)
|
201 |
+
train = train_dev['train']
|
202 |
+
dev = train_dev['test']
|
203 |
+
return train, dev, test
|
204 |
+
|
205 |
+
# pdb.set_trace()
|
206 |
+
# P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1)
|
207 |
+
# P1tony_train_ = concatenate_datasets([P1tony_train,P1tony_scripted])
|
208 |
+
# pdb.set_trace()
|
209 |
+
|
210 |
+
Michael_52_train, Michael_52_dev, Michael_52_test = train_dev_test_split(P3Micheal_dataset_52, dev_rate=0.15, test_rate=0.15, seed=1)  # dev_rate of 0.15 assumed (matches test_rate)
|
211 |
+
|
212 |
+
# train_dev / test
|
213 |
+
ds = src_dataset.train_test_split(test_size=0.1, seed=1)
|
214 |
+
|
215 |
+
# dataset_libri = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
216 |
+
|
217 |
+
train_dev = ds['train']
|
218 |
+
# train / dev
|
219 |
+
train_dev = train_dev.train_test_split(test_size=int(len(src_dataset)*0.1), seed=1)
|
220 |
+
# train/dev/test
|
221 |
+
train = train_dev['train']
|
222 |
+
test = ds['test']
|
223 |
+
dev = train_dev['test']
|
224 |
+
|
225 |
+
encoded_train = train.map(prepare_dataset, num_proc=4)
|
226 |
+
encoded_dev = dev.map(prepare_dataset, num_proc=4)
|
227 |
+
encoded_test = test.map(prepare_dataset, num_proc=4)
|
228 |
+
p326_encoded_train = p326_300_dataset.map(prepare_dataset, num_proc=4)
|
229 |
+
|
230 |
+
# combine large p326 in to training set
|
231 |
+
encoded_train = concatenate_datasets([encoded_train, p326_encoded_train])
|
232 |
+
|
233 |
+
# encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
|
234 |
+
# encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
|
235 |
+
# encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
|
236 |
+
# encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
|
237 |
+
|
238 |
+
# encoded_P1tony_train = P1tony_train.map(prepare_dataset, num_proc=4)
|
239 |
+
# encoded_P1tony_dev = P1tony_dev.map(prepare_dataset, num_proc=4)
|
240 |
+
# encoded_P1tony_test = P1tony_test.map(prepare_dataset, num_proc=4)
|
241 |
+
|
242 |
+
# pdb.set_trace()
|
243 |
+
import numpy as np
|
244 |
+
|
245 |
+
WER = evaluate.load("wer")
|
246 |
+
|
247 |
+
## Whisper decoding
|
248 |
+
|
249 |
+
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperFeatureExtractor, Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperModel
|
250 |
+
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
|
251 |
+
# model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to("cuda:0")
|
252 |
+
model = WhisperForConditionalGeneration.from_pretrained("./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400", use_auth_token=True).to("cuda:0")
|
253 |
+
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-medium", language="English", task="transcribe")
|
254 |
+
|
255 |
+
from pathlib import Path
|
256 |
+
id = Path(fine_tuning_dir).stem
|
257 |
+
pdb.set_trace()
|
258 |
+
tokenizer.push_to_hub("KevinGeng/%s"%id)
|
259 |
+
# import pdb
|
260 |
+
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-medium")
|
261 |
+
|
262 |
+
def whisper_prepare_dataset(batch):
|
263 |
+
# load and resample audio data from 48 to 16kHz
|
264 |
+
audio = batch["audio"]
|
265 |
+
|
266 |
+
# compute log-Mel input features from input audio array
|
267 |
+
batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
|
268 |
+
|
269 |
+
# encode target text to label ids
|
270 |
+
batch["labels"] = tokenizer(batch["transcription"]).input_ids
|
271 |
+
return batch
|
272 |
+
|
273 |
+
torch.cuda.empty_cache()
|
274 |
+
|
275 |
+
def my_map_to_pred(batch):
|
276 |
+
# pdb.set_trace()
|
277 |
+
audio = batch["audio"]
|
278 |
+
input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
|
279 |
+
# batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
|
280 |
+
batch["reference"] = processor.tokenizer._normalize(batch['transcription'])
|
281 |
+
|
282 |
+
with torch.no_grad():
|
283 |
+
# predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
|
284 |
+
predicted_ids = model.generate(input_features.to("cuda"))[0]
|
285 |
+
transcription = processor.decode(predicted_ids)  # decode via the processor/tokenizer; the generation model itself has no decode()
|
286 |
+
batch["prediction"] = model.tokenizer._normalize(transcription)
|
287 |
+
return batch
|
288 |
+
|
289 |
+
import torch
|
290 |
+
|
291 |
+
from dataclasses import dataclass
|
292 |
+
from typing import Any, Dict, List, Union
|
293 |
+
|
294 |
+
@dataclass
|
295 |
+
class DataCollatorSpeechSeq2SeqWithPadding:
|
296 |
+
processor: Any
|
297 |
+
|
298 |
+
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
299 |
+
# split inputs and labels since they have to be of different lengths and need different padding methods
|
300 |
+
# first treat the audio inputs by simply returning torch tensors
|
301 |
+
input_features = [{"input_features": feature["input_features"]} for feature in features]
|
302 |
+
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
303 |
+
|
304 |
+
# get the tokenized label sequences
|
305 |
+
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
306 |
+
# pad the labels to max length
|
307 |
+
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
|
308 |
+
|
309 |
+
# replace padding with -100 to ignore loss correctly
|
310 |
+
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
311 |
+
|
312 |
+
# if bos token is appended in previous tokenization step,
|
313 |
+
# cut bos token here as it's appended later anyway
|
314 |
+
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
|
315 |
+
labels = labels[:, 1:]
|
316 |
+
|
317 |
+
batch["labels"] = labels
|
318 |
+
|
319 |
+
return batch
|
320 |
+
|
321 |
+
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
|
322 |
+
|
323 |
+
def compute_metrics(pred):
|
324 |
+
pred_ids = pred.predictions
|
325 |
+
label_ids = pred.label_ids
|
326 |
+
|
327 |
+
# replace -100 with the pad_token_id
|
328 |
+
label_ids[label_ids == -100] = tokenizer.pad_token_id
|
329 |
+
|
330 |
+
# we do not want to group tokens when computing the metrics
|
331 |
+
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
332 |
+
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
|
333 |
+
|
334 |
+
wer = 100 * WER.compute(predictions=pred_str, references=label_str)
|
335 |
+
|
336 |
+
return {"wer": wer}
|
337 |
+
|
338 |
+
# whisper_train = train.map(whisper_prepare_dataset, num_proc=4)
|
339 |
+
# pdb.set_trace()
|
340 |
+
whisper_train_large = encoded_train.map(whisper_prepare_dataset, num_proc=4)
|
341 |
+
whisper_dev = dev.map(whisper_prepare_dataset, num_proc=4)
|
342 |
+
whisper_test = test.map(whisper_prepare_dataset, num_proc=4)
|
343 |
+
|
344 |
+
encoded_Michael_52_train = Michael_52_train.map(whisper_prepare_dataset, num_proc=4)
|
345 |
+
encoded_Michael_52_dev = Michael_52_dev.map(whisper_prepare_dataset, num_proc=4)
|
346 |
+
encoded_Michael_52_test = Michael_52_test.map(whisper_prepare_dataset, num_proc=4)
|
347 |
+
# pdb.set_trace()
|
348 |
+
# # Add scripted tony
|
349 |
+
# encoded_P1tony_train = P1tony_train_.map(whisper_prepare_dataset, num_proc=4)
|
350 |
+
# encoded_P1tony_dev = P1tony_dev.map(whisper_prepare_dataset, num_proc=4)
|
351 |
+
# encoded_P1tony_test = P1tony_test.map(whisper_prepare_dataset, num_proc=4)
|
352 |
+
|
353 |
+
# encode_negel_152_train = negel_152_train.map(whisper_prepare_dataset, num_proc=4)
|
354 |
+
# encode_negel_152_dev = negel_152_dev.map(whisper_prepare_dataset, num_proc=4)
|
355 |
+
# encode_negel_152_test = negel_152_test.map(whisper_prepare_dataset, num_proc=4)
|
356 |
+
|
357 |
+
# encoded_train_large = concatenate_datasets([whisper_train_large, encode_negel_152_train])
|
358 |
+
# encoded_dev_large = concatenate_datasets([whisper_dev, encode_negel_152_dev])
|
359 |
+
|
360 |
+
pdb.set_trace()
|
361 |
+
torch.cuda.empty_cache()
|
362 |
+
|
363 |
+
training_args = Seq2SeqTrainingArguments(
|
364 |
+
output_dir=fine_tuning_dir, # change to a repo name of your choice
|
365 |
+
per_device_train_batch_size=8,
|
366 |
+
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
|
367 |
+
learning_rate=1e-5,
|
368 |
+
warmup_steps=50,
|
369 |
+
max_steps=1000,
|
370 |
+
gradient_checkpointing=True,
|
371 |
+
fp16=True,
|
372 |
+
evaluation_strategy="steps",
|
373 |
+
save_strategy="steps",
|
374 |
+
per_device_eval_batch_size=8,
|
375 |
+
predict_with_generate=True,
|
376 |
+
generation_max_length=512,
|
377 |
+
save_steps=10,
|
378 |
+
eval_steps=10,
|
379 |
+
logging_steps=10,
|
380 |
+
report_to=["tensorboard"],
|
381 |
+
load_best_model_at_end=True,
|
382 |
+
metric_for_best_model="wer",
|
383 |
+
greater_is_better=False,
|
384 |
+
save_total_limit=5,
|
385 |
+
push_to_hub=False,
|
386 |
+
)
|
387 |
+
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
|
388 |
+
|
389 |
+
# pdb.set_trace()
|
390 |
+
# # from transformers.trainer.callbacks import TensorBoardCallback
|
391 |
+
# class EvalLoggingCallback(TrainerCallback):
|
392 |
+
# def on_evaluate(self, args, state, control, metrics, **kwargs):
|
393 |
+
# print(f"Eval loss: {metrics['eval_loss']:.4f}, Accuracy: {metrics['eval_wer']:.4f}")
|
394 |
+
|
395 |
+
# pdb.set_trace()
|
396 |
+
|
397 |
+
trainer = Seq2SeqTrainer(
|
398 |
+
args=training_args,
|
399 |
+
model=model,
|
400 |
+
train_dataset=encoded_Michael_52_train,
|
401 |
+
eval_dataset=encoded_Michael_52_dev,
|
402 |
+
data_collator=data_collator,
|
403 |
+
compute_metrics=compute_metrics,
|
404 |
+
tokenizer=processor.feature_extractor,
|
405 |
+
callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
|
406 |
+
|
407 |
+
)
|
408 |
+
# callbacks=[EvalLoggingCallback()]
|
409 |
+
trainer.train()
|
410 |
+
# trainer.evaluate(encoded_P1tony_test, metric_key_prefix="test")
|
411 |
+
# trainer.callback_handler.on_test_end(trainer, datasets=encoded_P1tony_test)
|
412 |
+
|
413 |
+
|
414 |
+
# ## Not fine tuned
|
415 |
+
# z_result = encoded_test.map(my_map_to_pred)
|
416 |
+
# # pdb.set_trace()
|
417 |
+
# # 0.4692737430167598
|
418 |
+
# z = WER.compute(references=z_result['reference'], predictions=z_result['prediction'])
|
419 |
+
|
420 |
+
# z_hel_result = encoded_healthy.map(my_map_to_pred)
|
421 |
+
# #
|
422 |
+
# z_hel = WER.compute(references=z_hel_result['reference'], predictions=z_hel_result['prediction'])
|
423 |
+
# # 0.1591610117211598
|
424 |
+
|
425 |
+
# z_fary_result = encoded_Fary.map(my_map_to_pred)
|
426 |
+
# z_far = WER.compute(references=z_fary_result['reference'], predictions=z_fary_result['prediction'])
|
427 |
+
# # 0.1791044776119403
|
428 |
+
|
429 |
+
|
430 |
+
# z_john_p326_result = encoded_John_p326.map(my_map_to_pred)
|
431 |
+
# z_john_p326 = WER.compute(references=z_john_p326_result['reference'], predictions=z_john_p326_result['prediction'])
|
432 |
+
# # 0.4648241206030151
|
433 |
+
|
434 |
+
# # y_John_video= fine_tuned_trainer.predict(encoded_John_video)
|
435 |
+
# # metrics={'test_loss': 2.665189743041992, 'test_wer': 0.7222222222222222, 'test_runtime': 0.1633, 'test_samples_per_second': 48.979, 'test_steps_per_second': 6.122})
|
436 |
+
# pdb.set_trace()
|
437 |
+
|
438 |
+
# p326 training
|
439 |
+
# metrics={'test_loss': 0.4804028868675232, 'test_wer': 0.21787709497206703, 'test_runtime': 0.3594, 'test_samples_per_second': 44.517, 'test_steps_per_second': 5.565})
|
440 |
+
# hel metrics={'test_loss': 1.6363693475723267, 'test_wer': 0.17951881554595928, 'test_runtime': 3.8451, 'test_samples_per_second': 41.611, 'test_steps_per_second': 5.201})
|
441 |
+
# Fary: metrics={'test_loss': 1.4633615016937256, 'test_wer': 0.5572139303482587, 'test_runtime': 0.6627, 'test_samples_per_second': 45.27, 'test_steps_per_second': 6.036})
|
442 |
+
# p326 large: metrics={'test_loss': 0.6568527817726135, 'test_wer': 0.2889447236180904, 'test_runtime': 0.7169, 'test_samples_per_second': 51.613, 'test_steps_per_second': 6.975})
|
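The commented-out WER checks above rely on a my_map_to_pred that decodes with the processor rather than the model object. A minimal stand-alone decoding-plus-WER sketch of that evaluation path, assuming the checkpoint directory hard-coded in this script and a datasets split with "audio" and "transcription" columns (the whisper_test split below is only an illustration):

import torch
import evaluate
from transformers import WhisperProcessor, WhisperForConditionalGeneration

ckpt = "./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400"  # checkpoint path used above
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
model = WhisperForConditionalGeneration.from_pretrained(ckpt).to("cuda:0").eval()
wer_metric = evaluate.load("wer")

def map_to_pred(batch):
    # feature-extract one utterance and greedily decode it with the fine-tuned model
    audio = batch["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    with torch.no_grad():
        pred_ids = model.generate(inputs.input_features.to("cuda:0"))[0]
    text = processor.decode(pred_ids, skip_special_tokens=True)
    batch["prediction"] = processor.tokenizer._normalize(text)
    batch["reference"] = processor.tokenizer._normalize(batch["transcription"])
    return batch

# results = whisper_test.map(map_to_pred)
# print(wer_metric.compute(references=results["reference"], predictions=results["prediction"]))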
local/whisper_fine_tuning_michael_100.py
ADDED
@@ -0,0 +1,442 @@
1 |
+
fine_tuning_dir = "fine_tuned/SSD/model/Michael_100_with_Large_AVA_script_conv_train_conv_dev"
|
2 |
+
"""
|
3 |
+
TODO:
|
4 |
+
+ [x] Load Configuration
|
5 |
+
+ [ ] Multi ASR Engine
|
6 |
+
+ [ ] Batch / Real Time support
|
7 |
+
"""
|
8 |
+
from pathlib import Path
|
9 |
+
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
|
10 |
+
from datasets import load_dataset, concatenate_datasets
|
11 |
+
from datasets import Dataset, Audio
|
12 |
+
import pdb
|
13 |
+
import string
|
14 |
+
import librosa
|
15 |
+
# local import
|
16 |
+
import sys
|
17 |
+
|
18 |
+
sys.path.append("src")
|
19 |
+
import torch
|
20 |
+
torch.cuda.set_device("cuda:0")
|
21 |
+
# token_model = AutoModelForCTC.from_pretrained(
|
22 |
+
# "facebook/wav2vec2-base-960h"
|
23 |
+
# )
|
24 |
+
|
25 |
+
# audio_dir= "/Users/kevingeng/Laronix/laronix_automos/data/Patient_sil_trim_16k_normed_5_snr_40/"
|
26 |
+
audio_dir ="./data/Patient_sil_trim_16k_normed_5_snr_40"
|
27 |
+
healthy_dir="./data/Healthy"
|
28 |
+
Fary_PAL_30="./data/Fary_PAL_p326_20230110_30"
|
29 |
+
John_p326 = "./data/John_p326/output"
|
30 |
+
John_video = "./data/20230103_video"
|
31 |
+
p326_300_dir ="./data/John_p326_large"
|
32 |
+
negel_152 = "./data/4_negal_152_clean_all"
|
33 |
+
|
34 |
+
michael3_52 = "data/3_michael_20230619_100"
|
35 |
+
|
36 |
+
patient_T = "data/Patient_T/Patient_T"
|
37 |
+
patient_L = "data/Patient_L/Patient_L"
|
38 |
+
P1tony = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/CONVERSATIONAL/PAL"
|
39 |
+
P1tony_arthur = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Arthur_the_Rat/PAL"
|
40 |
+
P1tony_rainbow = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Rainbow_Passage/Laronix"
|
41 |
+
|
42 |
+
def dataclean(example):
|
43 |
+
# pdb.set_trace()
|
44 |
+
if example['audio']['sampling_rate'] != 16000:
|
45 |
+
resampled_audio = librosa.resample(y=example['audio']['array'],
|
46 |
+
orig_sr= example['audio']['sampling_rate'],
|
47 |
+
target_sr=16000)
|
48 |
+
# torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
|
49 |
+
# resampled_audio = resampler(example['audio']['array'])
|
50 |
+
|
51 |
+
return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
|
52 |
+
"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
53 |
+
else:
|
54 |
+
return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
55 |
+
|
56 |
+
# patient_L_test_dataset = load_dataset("audiofolder", data_dir=patient_L, split="train")
|
57 |
+
# patient_L_test_dataset = patient_L_test_dataset.map(dataclean)
|
58 |
+
|
59 |
+
# patient_T_test_dataset = load_dataset("audiofolder", data_dir=patient_T, split="train")
|
60 |
+
# patient_T_test_dataset = patient_T_test_dataset.map(dataclean)
|
61 |
+
|
62 |
+
P1tony_dataset = load_dataset("audiofolder", data_dir=P1tony, split="train")
|
63 |
+
P1tony_dataset = P1tony_dataset.map(dataclean)
|
64 |
+
|
65 |
+
P3Micheal_dataset_52 = load_dataset("audiofolder", data_dir=michael3_52, split="train")
|
66 |
+
P3Micheal_dataset_52 = P3Micheal_dataset_52.map(dataclean)
|
67 |
+
|
68 |
+
# negel_152_dataset = load_dataset("audiofolder", data_dir=negel_152, split="train")
|
69 |
+
# negel_152_dataset = negel_152_dataset.map(dataclean)
|
70 |
+
|
71 |
+
|
72 |
+
# pdb.set_trace()
|
73 |
+
# P1tony_scripted1 = load_dataset("audiofolder", data_dir=P1tony_rainbow, split="train")
|
74 |
+
# P1tony_scripted2 = load_dataset("audiofolder", data_dir=P1tony_arthur, split="train")
|
75 |
+
# P1tony_scripted1 = P1tony_scripted1.map(dataclean)
|
76 |
+
# P1tony_scripted2 = P1tony_scripted2.map(dataclean)
|
77 |
+
# P1tony_scripted = concatenate_datasets([P1tony_scripted1, P1tony_scripted2])
|
78 |
+
|
79 |
+
# audio_dir ="/home/kevingeng/laronix/laronix_automos/data/Healthy"
|
80 |
+
# tgt_audio_dir= "/Users/kevingeng/Laronix/Dataset/Pneumatic/automos"
|
81 |
+
|
82 |
+
# Get Transcription, WER and PPM
|
83 |
+
"""
|
84 |
+
TODO:
|
85 |
+
[DONE]: Automatic generating Config
|
86 |
+
"""
|
87 |
+
|
88 |
+
import yaml
|
89 |
+
import argparse
|
90 |
+
import sys
|
91 |
+
from pathlib import Path
|
92 |
+
|
93 |
+
sys.path.append("./src")
|
94 |
+
import lightning_module
|
95 |
+
# from UV import plot_UV, get_speech_interval
|
96 |
+
from transformers import pipeline
|
97 |
+
from rich.progress import track
|
98 |
+
from rich import print as rprint
|
99 |
+
import numpy as np
|
100 |
+
import jiwer
|
101 |
+
import pdb
|
102 |
+
import torch.nn as nn
|
103 |
+
import torch
|
104 |
+
import torchaudio
|
105 |
+
import gradio as gr
|
106 |
+
from sys import flags
|
107 |
+
from random import sample
|
108 |
+
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
109 |
+
|
110 |
+
import evaluate
|
111 |
+
|
112 |
+
wer = evaluate.load("wer")
|
113 |
+
|
114 |
+
# root_path = Path(__file__).parents[1]
|
115 |
+
|
116 |
+
class ChangeSampleRate(nn.Module):
|
117 |
+
def __init__(self, input_rate: int, output_rate: int):
|
118 |
+
super().__init__()
|
119 |
+
self.output_rate = output_rate
|
120 |
+
self.input_rate = input_rate
|
121 |
+
|
122 |
+
def forward(self, wav: torch.tensor) -> torch.tensor:
|
123 |
+
# Only accepts 1-channel waveform input
|
124 |
+
wav = wav.view(wav.size(0), -1)
|
125 |
+
new_length = wav.size(-1) * self.output_rate // self.input_rate
|
126 |
+
indices = torch.arange(new_length) * (
|
127 |
+
self.input_rate / self.output_rate
|
128 |
+
)
|
129 |
+
round_down = wav[:, indices.long()]
|
130 |
+
round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
|
131 |
+
output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
|
132 |
+
0
|
133 |
+
) + round_up * indices.fmod(1.0).unsqueeze(0)
|
134 |
+
return output
|
135 |
+
|
136 |
+
# resample and clean text data
|
137 |
+
def dataclean(example):
|
138 |
+
# pdb.set_trace()
|
139 |
+
if example['audio']['sampling_rate'] != 16000:
|
140 |
+
resampled_audio = librosa.resample(y=example['audio']['array'],
|
141 |
+
orig_sr= example['audio']['sampling_rate'],
|
142 |
+
target_sr=16000)
|
143 |
+
# torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
|
144 |
+
# resampled_audio = resampler(example['audio']['array'])
|
145 |
+
|
146 |
+
return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
|
147 |
+
"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
148 |
+
else:
|
149 |
+
return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
150 |
+
|
151 |
+
# processor = AutoFeatureExtractor.from_pretrained(
|
152 |
+
# "facebook/wav2vec2-base-960h"
|
153 |
+
# )
|
154 |
+
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
155 |
+
|
156 |
+
def prepare_dataset(batch):
|
157 |
+
audio = batch["audio"]
|
158 |
+
batch = processor(audio["array"], sampling_rate = audio["sampling_rate"], text=batch['transcription'])
|
159 |
+
batch["input_length"] = len(batch["input_values"][0])
|
160 |
+
return batch
|
161 |
+
|
162 |
+
src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
|
163 |
+
src_dataset = src_dataset.map(dataclean)
|
164 |
+
p326_300_dataset = load_dataset("audiofolder", data_dir=p326_300_dir, split="train")
|
165 |
+
p326_300_dataset = p326_300_dataset.map(dataclean)
|
166 |
+
|
167 |
+
# healthy_test_dataset = load_dataset("audiofolder", data_dir=healthy_dir, split='train')
|
168 |
+
# healthy_test_dataset = healthy_test_dataset.map(dataclean)
|
169 |
+
|
170 |
+
# Fary_PAL_test_dataset = load_dataset("audiofolder", data_dir=Fary_PAL_30, split='train')
|
171 |
+
# Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)
|
172 |
+
|
173 |
+
# John_p326_test_dataset = load_dataset("audiofolder", data_dir=John_p326, split='train')
|
174 |
+
# John_p326_test_dataset = John_p326_test_dataset.map(dataclean)
|
175 |
+
|
176 |
+
# John_video_test_dataset = load_dataset("audiofolder", data_dir=John_video, split='train')
|
177 |
+
# John_video_test_dataset = John_video_test_dataset.map(dataclean)
|
178 |
+
|
179 |
+
# pdb.set_trace()
|
180 |
+
|
181 |
+
def train_dev_test_split(dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1):
|
182 |
+
"""
|
183 |
+
input: dataset
|
184 |
+
dev_rate,
|
185 |
+
test_rate
|
186 |
+
seed
|
187 |
+
-------
|
188 |
+
Output:
|
189 |
+
train, dev, test (three Dataset splits)
|
190 |
+
"""
|
191 |
+
train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
|
192 |
+
test = train_dev_test["test"]
|
193 |
+
train_dev = train_dev_test['train']
|
194 |
+
|
195 |
+
# pdb.set_trace()
|
196 |
+
if len(train_dev) <= int(len(dataset)*dev_rate):
|
197 |
+
train = Dataset.from_dict({"audio": [], "transcription": []})
|
198 |
+
dev = train_dev
|
199 |
+
else:
|
200 |
+
train_dev = train_dev.train_test_split(test_size=int(len(dataset)*dev_rate), seed=seed)
|
201 |
+
train = train_dev['train']
|
202 |
+
dev = train_dev['test']
|
203 |
+
return train, dev, test
|
204 |
+
|
205 |
+
# pdb.set_trace()
|
206 |
+
# P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1)
|
207 |
+
# P1tony_train_ = concatenate_datasets([P1tony_train,P1tony_scripted])
|
208 |
+
# pdb.set_trace()
|
209 |
+
|
210 |
+
Michael_52_train, Michael_52_dev, Michael_52_test = train_dev_test_split(P3Micheal_dataset_52, dev_rate=0.10, test_rate=0.1, seed=1)
|
211 |
+
|
212 |
+
# train_dev / test
|
213 |
+
ds = src_dataset.train_test_split(test_size=0.1, seed=1)
|
214 |
+
|
215 |
+
# dataset_libri = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
216 |
+
|
217 |
+
train_dev = ds['train']
|
218 |
+
# train / dev
|
219 |
+
train_dev = train_dev.train_test_split(test_size=int(len(src_dataset)*0.1), seed=1)
|
220 |
+
# train/dev/test
|
221 |
+
train = train_dev['train']
|
222 |
+
test = ds['test']
|
223 |
+
dev = train_dev['test']
|
224 |
+
|
225 |
+
encoded_train = train.map(prepare_dataset, num_proc=4)
|
226 |
+
encoded_dev = dev.map(prepare_dataset, num_proc=4)
|
227 |
+
encoded_test = test.map(prepare_dataset, num_proc=4)
|
228 |
+
p326_encoded_train = p326_300_dataset.map(prepare_dataset, num_proc=4)
|
229 |
+
|
230 |
+
# combine large p326 into training set
|
231 |
+
encoded_train = concatenate_datasets([encoded_train, p326_encoded_train])
|
232 |
+
|
233 |
+
# encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
|
234 |
+
# encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
|
235 |
+
# encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
|
236 |
+
# encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
|
237 |
+
|
238 |
+
# encoded_P1tony_train = P1tony_train.map(prepare_dataset, num_proc=4)
|
239 |
+
# encoded_P1tony_dev = P1tony_dev.map(prepare_dataset, num_proc=4)
|
240 |
+
# encoded_P1tony_test = P1tony_test.map(prepare_dataset, num_proc=4)
|
241 |
+
|
242 |
+
# pdb.set_trace()
|
243 |
+
import numpy as np
|
244 |
+
|
245 |
+
WER = evaluate.load("wer")
|
246 |
+
|
247 |
+
## Whisper decoding
|
248 |
+
|
249 |
+
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperFeatureExtractor, Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperModel
|
250 |
+
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
|
251 |
+
# model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to("cuda:0")
|
252 |
+
model = WhisperForConditionalGeneration.from_pretrained("./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400", use_auth_token=True).to("cuda:0")
|
253 |
+
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-medium", language="English", task="transcribe")
|
254 |
+
|
255 |
+
from pathlib import Path
|
256 |
+
id = Path(fine_tuning_dir).stem
|
257 |
+
pdb.set_trace()
|
258 |
+
tokenizer.push_to_hub("KevinGeng/%s"%id)
|
259 |
+
# import pdb
|
260 |
+
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-medium")
|
261 |
+
|
262 |
+
def whisper_prepare_dataset(batch):
|
263 |
+
# load and resample audio data from 48 to 16kHz
|
264 |
+
audio = batch["audio"]
|
265 |
+
|
266 |
+
# compute log-Mel input features from input audio array
|
267 |
+
batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
|
268 |
+
|
269 |
+
# encode target text to label ids
|
270 |
+
batch["labels"] = tokenizer(batch["transcription"]).input_ids
|
271 |
+
return batch
|
272 |
+
|
273 |
+
torch.cuda.empty_cache()
|
274 |
+
|
275 |
+
def my_map_to_pred(batch):
|
276 |
+
# pdb.set_trace()
|
277 |
+
audio = batch["audio"]
|
278 |
+
input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
|
279 |
+
# batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
|
280 |
+
batch["reference"] = processor.tokenizer._normalize(batch['transcription'])
|
281 |
+
|
282 |
+
with torch.no_grad():
|
283 |
+
# predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
|
284 |
+
predicted_ids = model.generate(input_features.to("cuda"))[0]
|
285 |
+
transcription = processor.decode(predicted_ids)
|
286 |
+
batch["prediction"] = processor.tokenizer._normalize(transcription)
|
287 |
+
return batch
|
288 |
+
|
289 |
+
import torch
|
290 |
+
|
291 |
+
from dataclasses import dataclass
|
292 |
+
from typing import Any, Dict, List, Union
|
293 |
+
|
294 |
+
@dataclass
|
295 |
+
class DataCollatorSpeechSeq2SeqWithPadding:
|
296 |
+
processor: Any
|
297 |
+
|
298 |
+
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
299 |
+
# split inputs and labels since they have to be of different lengths and need different padding methods
|
300 |
+
# first treat the audio inputs by simply returning torch tensors
|
301 |
+
input_features = [{"input_features": feature["input_features"]} for feature in features]
|
302 |
+
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
303 |
+
|
304 |
+
# get the tokenized label sequences
|
305 |
+
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
306 |
+
# pad the labels to max length
|
307 |
+
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
|
308 |
+
|
309 |
+
# replace padding with -100 to ignore loss correctly
|
310 |
+
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
311 |
+
|
312 |
+
# if bos token is appended in previous tokenization step,
|
313 |
+
# cut bos token here as it's appended later anyway
|
314 |
+
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
|
315 |
+
labels = labels[:, 1:]
|
316 |
+
|
317 |
+
batch["labels"] = labels
|
318 |
+
|
319 |
+
return batch
|
320 |
+
|
321 |
+
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
|
322 |
+
|
323 |
+
def compute_metrics(pred):
|
324 |
+
pred_ids = pred.predictions
|
325 |
+
label_ids = pred.label_ids
|
326 |
+
|
327 |
+
# replace -100 with the pad_token_id
|
328 |
+
label_ids[label_ids == -100] = tokenizer.pad_token_id
|
329 |
+
|
330 |
+
# we do not want to group tokens when computing the metrics
|
331 |
+
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
332 |
+
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
|
333 |
+
|
334 |
+
wer = 100 * WER.compute(predictions=pred_str, references=label_str)
|
335 |
+
|
336 |
+
return {"wer": wer}
|
337 |
+
|
338 |
+
# whisper_train = train.map(whisper_prepare_dataset, num_proc=4)
|
339 |
+
# pdb.set_trace()
|
340 |
+
whisper_train_large = encoded_train.map(whisper_prepare_dataset, num_proc=4)
|
341 |
+
whisper_dev = dev.map(whisper_prepare_dataset, num_proc=4)
|
342 |
+
whisper_test = test.map(whisper_prepare_dataset, num_proc=4)
|
343 |
+
|
344 |
+
encoded_Michael_52_train = Michael_52_train.map(whisper_prepare_dataset, num_proc=4)
|
345 |
+
encoded_Michael_52_dev = Michael_52_dev.map(whisper_prepare_dataset, num_proc=4)
|
346 |
+
encoded_Michael_52_test = Michael_52_test.map(whisper_prepare_dataset, num_proc=4)
|
347 |
+
# pdb.set_trace()
|
348 |
+
# # Add scripted tony
|
349 |
+
# encoded_P1tony_train = P1tony_train_.map(whisper_prepare_dataset, num_proc=4)
|
350 |
+
# encoded_P1tony_dev = P1tony_dev.map(whisper_prepare_dataset, num_proc=4)
|
351 |
+
# encoded_P1tony_test = P1tony_test.map(whisper_prepare_dataset, num_proc=4)
|
352 |
+
|
353 |
+
# encode_negel_152_train = negel_152_train.map(whisper_prepare_dataset, num_proc=4)
|
354 |
+
# encode_negel_152_dev = negel_152_dev.map(whisper_prepare_dataset, num_proc=4)
|
355 |
+
# encode_negel_152_test = negel_152_test.map(whisper_prepare_dataset, num_proc=4)
|
356 |
+
|
357 |
+
# encoded_train_large = concatenate_datasets([whisper_train_large, encode_negel_152_train])
|
358 |
+
# encoded_dev_large = concatenate_datasets([whisper_dev, encode_negel_152_dev])
|
359 |
+
|
360 |
+
pdb.set_trace()
|
361 |
+
torch.cuda.empty_cache()
|
362 |
+
|
363 |
+
training_args = Seq2SeqTrainingArguments(
|
364 |
+
output_dir=fine_tuning_dir, # change to a repo name of your choice
|
365 |
+
per_device_train_batch_size=8,
|
366 |
+
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
|
367 |
+
learning_rate=1e-5,
|
368 |
+
warmup_steps=50,
|
369 |
+
max_steps=1000,
|
370 |
+
gradient_checkpointing=True,
|
371 |
+
fp16=True,
|
372 |
+
evaluation_strategy="steps",
|
373 |
+
save_strategy="steps",
|
374 |
+
per_device_eval_batch_size=8,
|
375 |
+
predict_with_generate=True,
|
376 |
+
generation_max_length=512,
|
377 |
+
save_steps=10,
|
378 |
+
eval_steps=10,
|
379 |
+
logging_steps=10,
|
380 |
+
report_to=["tensorboard"],
|
381 |
+
load_best_model_at_end=True,
|
382 |
+
metric_for_best_model="wer",
|
383 |
+
greater_is_better=False,
|
384 |
+
save_total_limit=5,
|
385 |
+
push_to_hub=False,
|
386 |
+
)
|
387 |
+
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
|
388 |
+
|
389 |
+
# pdb.set_trace()
|
390 |
+
# # from transformers.trainer.callbacks import TensorBoardCallback
|
391 |
+
# class EvalLoggingCallback(TrainerCallback):
|
392 |
+
# def on_evaluate(self, args, state, control, metrics, **kwargs):
|
393 |
+
# print(f"Eval loss: {metrics['eval_loss']:.4f}, Accuracy: {metrics['eval_wer']:.4f}")
|
394 |
+
|
395 |
+
# pdb.set_trace()
|
396 |
+
|
397 |
+
trainer = Seq2SeqTrainer(
|
398 |
+
args=training_args,
|
399 |
+
model=model,
|
400 |
+
train_dataset=encoded_Michael_52_train,
|
401 |
+
eval_dataset=encoded_Michael_52_dev,
|
402 |
+
data_collator=data_collator,
|
403 |
+
compute_metrics=compute_metrics,
|
404 |
+
tokenizer=processor.feature_extractor,
|
405 |
+
callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
|
406 |
+
|
407 |
+
)
|
408 |
+
# callbacks=[EvalLoggingCallback()]
|
409 |
+
trainer.train()
|
410 |
+
# trainer.evaluate(encoded_P1tony_test, metric_key_prefix="test")
|
411 |
+
# trainer.callback_handler.on_test_end(trainer, datasets=encoded_P1tony_test)
|
412 |
+
|
413 |
+
|
414 |
+
# ## Not fine tuned
|
415 |
+
# z_result = encoded_test.map(my_map_to_pred)
|
416 |
+
# # pdb.set_trace()
|
417 |
+
# # 0.4692737430167598
|
418 |
+
# z = WER.compute(references=z_result['reference'], predictions=z_result['prediction'])
|
419 |
+
|
420 |
+
# z_hel_result = encoded_healthy.map(my_map_to_pred)
|
421 |
+
# #
|
422 |
+
# z_hel = WER.compute(references=z_hel_result['reference'], predictions=z_hel_result['prediction'])
|
423 |
+
# # 0.1591610117211598
|
424 |
+
|
425 |
+
# z_fary_result = encoded_Fary.map(my_map_to_pred)
|
426 |
+
# z_far = WER.compute(references=z_fary_result['reference'], predictions=z_fary_result['prediction'])
|
427 |
+
# # 0.1791044776119403
|
428 |
+
|
429 |
+
|
430 |
+
# z_john_p326_result = encoded_John_p326.map(my_map_to_pred)
|
431 |
+
# z_john_p326 = WER.compute(references=z_john_p326_result['reference'], predictions=z_john_p326_result['prediction'])
|
432 |
+
# # 0.4648241206030151
|
433 |
+
|
434 |
+
# # y_John_video= fine_tuned_trainer.predict(encoded_John_video)
|
435 |
+
# # metrics={'test_loss': 2.665189743041992, 'test_wer': 0.7222222222222222, 'test_runtime': 0.1633, 'test_samples_per_second': 48.979, 'test_steps_per_second': 6.122})
|
436 |
+
# pdb.set_trace()
|
437 |
+
|
438 |
+
# p326 training
|
439 |
+
# metrics={'test_loss': 0.4804028868675232, 'test_wer': 0.21787709497206703, 'test_runtime': 0.3594, 'test_samples_per_second': 44.517, 'test_steps_per_second': 5.565})
|
440 |
+
# hel metrics={'test_loss': 1.6363693475723267, 'test_wer': 0.17951881554595928, 'test_runtime': 3.8451, 'test_samples_per_second': 41.611, 'test_steps_per_second': 5.201})
|
441 |
+
# Fary: metrics={'test_loss': 1.4633615016937256, 'test_wer': 0.5572139303482587, 'test_runtime': 0.6627, 'test_samples_per_second': 45.27, 'test_steps_per_second': 6.036})
|
442 |
+
# p326 large: metrics={'test_loss': 0.6568527817726135, 'test_wer': 0.2889447236180904, 'test_runtime': 0.7169, 'test_samples_per_second': 51.613, 'test_steps_per_second': 6.975})
|
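DataCollatorSpeechSeq2SeqWithPadding above pads the audio features and the label ids separately, then replaces label padding with -100 so the cross-entropy loss ignores those positions. A small self-contained sketch of just that masking step, using two made-up sentences rather than the real transcripts:

import torch
from transformers import WhisperTokenizer

tok = WhisperTokenizer.from_pretrained("openai/whisper-medium", language="English", task="transcribe")
# two label sequences of different lengths, as whisper_prepare_dataset would produce
label_features = [{"input_ids": tok("HELLO WORLD").input_ids},
                  {"input_ids": tok("HI").input_ids}]
labels_batch = tok.pad(label_features, return_tensors="pt")
# pad positions get -100 so they do not contribute to the fine-tuning loss
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
print(labels)  # the shorter sequence ends in a run of -100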
local/whisper_fine_tuning_negel.py
ADDED
@@ -0,0 +1,431 @@
1 |
+
fine_tuning_dir = "fine_tuned/SSD/model/Negel_152_AVA_script_conv_train_conv_dev"
|
2 |
+
"""
|
3 |
+
TODO:
|
4 |
+
+ [x] Load Configuration
|
5 |
+
+ [ ] Multi ASR Engine
|
6 |
+
+ [ ] Batch / Real Time support
|
7 |
+
"""
|
8 |
+
from pathlib import Path
|
9 |
+
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
|
10 |
+
from datasets import load_dataset, concatenate_datasets
|
11 |
+
from datasets import Dataset, Audio
|
12 |
+
import pdb
|
13 |
+
import string
|
14 |
+
import librosa
|
15 |
+
# local import
|
16 |
+
import sys
|
17 |
+
|
18 |
+
sys.path.append("src")
|
19 |
+
import torch
|
20 |
+
torch.cuda.set_device("cuda:0")
|
21 |
+
# token_model = AutoModelForCTC.from_pretrained(
|
22 |
+
# "facebook/wav2vec2-base-960h"
|
23 |
+
# )
|
24 |
+
|
25 |
+
# audio_dir= "/Users/kevingeng/Laronix/laronix_automos/data/Patient_sil_trim_16k_normed_5_snr_40/"
|
26 |
+
audio_dir ="./data/Patient_sil_trim_16k_normed_5_snr_40"
|
27 |
+
healthy_dir="./data/Healthy"
|
28 |
+
Fary_PAL_30="./data/Fary_PAL_p326_20230110_30"
|
29 |
+
John_p326 = "./data/John_p326/output"
|
30 |
+
John_video = "./data/20230103_video"
|
31 |
+
p326_300_dir ="./data/John_p326_large"
|
32 |
+
|
33 |
+
negel_152 = "./data/4_negal_152_clean_all"
|
34 |
+
|
35 |
+
patient_T = "data/Patient_T/Patient_T"
|
36 |
+
patient_L = "data/Patient_L/Patient_L"
|
37 |
+
P1tony = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/CONVERSATIONAL/PAL"
|
38 |
+
P1tony_arthur = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Arthur_the_Rat/PAL"
|
39 |
+
P1tony_rainbow = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Rainbow_Passage/Laronix"
|
40 |
+
|
41 |
+
def dataclean(example):
|
42 |
+
# pdb.set_trace()
|
43 |
+
if example['audio']['sampling_rate'] != 16000:
|
44 |
+
resampled_audio = librosa.resample(y=example['audio']['array'],
|
45 |
+
orig_sr= example['audio']['sampling_rate'],
|
46 |
+
target_sr=16000)
|
47 |
+
# torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
|
48 |
+
# resampled_audio = resampler(example['audio']['array'])
|
49 |
+
|
50 |
+
return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
|
51 |
+
"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
52 |
+
else:
|
53 |
+
return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
54 |
+
|
55 |
+
# patient_L_test_dataset = load_dataset("audiofolder", data_dir=patient_L, split="train")
|
56 |
+
# patient_L_test_dataset = patient_L_test_dataset.map(dataclean)
|
57 |
+
|
58 |
+
# patient_T_test_dataset = load_dataset("audiofolder", data_dir=patient_T, split="train")
|
59 |
+
# patient_T_test_dataset = patient_T_test_dataset.map(dataclean)
|
60 |
+
|
61 |
+
P1tony_dataset = load_dataset("audiofolder", data_dir=P1tony, split="train")
|
62 |
+
P1tony_dataset = P1tony_dataset.map(dataclean)
|
63 |
+
|
64 |
+
negel_152_dataset = load_dataset("audiofolder", data_dir=negel_152, split="train")
|
65 |
+
negel_152_dataset = negel_152_dataset.map(dataclean)
|
66 |
+
# pdb.set_trace()
|
67 |
+
# P1tony_scripted1 = load_dataset("audiofolder", data_dir=P1tony_rainbow, split="train")
|
68 |
+
# P1tony_scripted2 = load_dataset("audiofolder", data_dir=P1tony_arthur, split="train")
|
69 |
+
# P1tony_scripted1 = P1tony_scripted1.map(dataclean)
|
70 |
+
# P1tony_scripted2 = P1tony_scripted2.map(dataclean)
|
71 |
+
# P1tony_scripted = concatenate_datasets([P1tony_scripted1, P1tony_scripted2])
|
72 |
+
|
73 |
+
# audio_dir ="/home/kevingeng/laronix/laronix_automos/data/Healthy"
|
74 |
+
# tgt_audio_dir= "/Users/kevingeng/Laronix/Dataset/Pneumatic/automos"
|
75 |
+
|
76 |
+
# Get Transcription, WER and PPM
|
77 |
+
"""
|
78 |
+
TODO:
|
79 |
+
[DONE]: Automatic generating Config
|
80 |
+
"""
|
81 |
+
|
82 |
+
import yaml
|
83 |
+
import argparse
|
84 |
+
import sys
|
85 |
+
from pathlib import Path
|
86 |
+
|
87 |
+
sys.path.append("./src")
|
88 |
+
import lightning_module
|
89 |
+
from UV import plot_UV, get_speech_interval
|
90 |
+
from transformers import pipeline
|
91 |
+
from rich.progress import track
|
92 |
+
from rich import print as rprint
|
93 |
+
import numpy as np
|
94 |
+
import jiwer
|
95 |
+
import pdb
|
96 |
+
import torch.nn as nn
|
97 |
+
import torch
|
98 |
+
import torchaudio
|
99 |
+
import gradio as gr
|
100 |
+
from sys import flags
|
101 |
+
from random import sample
|
102 |
+
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
103 |
+
|
104 |
+
import evaluate
|
105 |
+
|
106 |
+
wer = evaluate.load("wer")
|
107 |
+
|
108 |
+
# root_path = Path(__file__).parents[1]
|
109 |
+
|
110 |
+
class ChangeSampleRate(nn.Module):
|
111 |
+
def __init__(self, input_rate: int, output_rate: int):
|
112 |
+
super().__init__()
|
113 |
+
self.output_rate = output_rate
|
114 |
+
self.input_rate = input_rate
|
115 |
+
|
116 |
+
def forward(self, wav: torch.tensor) -> torch.tensor:
|
117 |
+
# Only accepts 1-channel waveform input
|
118 |
+
wav = wav.view(wav.size(0), -1)
|
119 |
+
new_length = wav.size(-1) * self.output_rate // self.input_rate
|
120 |
+
indices = torch.arange(new_length) * (
|
121 |
+
self.input_rate / self.output_rate
|
122 |
+
)
|
123 |
+
round_down = wav[:, indices.long()]
|
124 |
+
round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
|
125 |
+
output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
|
126 |
+
0
|
127 |
+
) + round_up * indices.fmod(1.0).unsqueeze(0)
|
128 |
+
return output
|
129 |
+
|
130 |
+
# resample and clean text data
|
131 |
+
def dataclean(example):
|
132 |
+
# pdb.set_trace()
|
133 |
+
if example['audio']['sampling_rate'] != 16000:
|
134 |
+
resampled_audio = librosa.resample(y=example['audio']['array'],
|
135 |
+
orig_sr= example['audio']['sampling_rate'],
|
136 |
+
target_sr=16000)
|
137 |
+
# torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
|
138 |
+
# resampled_audio = resampler(example['audio']['array'])
|
139 |
+
|
140 |
+
return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
|
141 |
+
"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
142 |
+
else:
|
143 |
+
return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
|
144 |
+
|
145 |
+
# processor = AutoFeatureExtractor.from_pretrained(
|
146 |
+
# "facebook/wav2vec2-base-960h"
|
147 |
+
# )
|
148 |
+
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
149 |
+
|
150 |
+
def prepare_dataset(batch):
|
151 |
+
audio = batch["audio"]
|
152 |
+
batch = processor(audio["array"], sampling_rate = audio["sampling_rate"], text=batch['transcription'])
|
153 |
+
batch["input_length"] = len(batch["input_values"][0])
|
154 |
+
return batch
|
155 |
+
|
156 |
+
# src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
|
157 |
+
# src_dataset = src_dataset.map(dataclean)
|
158 |
+
# p326_300_dataset = load_dataset("audiofolder", data_dir=p326_300_dir, split="train")
|
159 |
+
# p326_300_dataset = p326_300_dataset.map(dataclean)
|
160 |
+
|
161 |
+
|
162 |
+
# healthy_test_dataset = load_dataset("audiofolder", data_dir=healthy_dir, split='train')
|
163 |
+
# healthy_test_dataset = healthy_test_dataset.map(dataclean)
|
164 |
+
|
165 |
+
# Fary_PAL_test_dataset = load_dataset("audiofolder", data_dir=Fary_PAL_30, split='train')
|
166 |
+
# Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)
|
167 |
+
|
168 |
+
# John_p326_test_dataset = load_dataset("audiofolder", data_dir=John_p326, split='train')
|
169 |
+
# John_p326_test_dataset = John_p326_test_dataset.map(dataclean)
|
170 |
+
|
171 |
+
# John_video_test_dataset = load_dataset("audiofolder", data_dir=John_video, split='train')
|
172 |
+
# John_video_test_dataset = John_video_test_dataset.map(dataclean)
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
# pdb.set_trace()
|
177 |
+
|
178 |
+
def train_dev_test_split(dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1):
|
179 |
+
"""
|
180 |
+
input: dataset
|
181 |
+
dev_rate,
|
182 |
+
test_rate
|
183 |
+
seed
|
184 |
+
-------
|
185 |
+
Output:
|
186 |
+
train, dev, test (three Dataset splits)
|
187 |
+
"""
|
188 |
+
train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
|
189 |
+
test = train_dev_test["test"]
|
190 |
+
train_dev = train_dev_test['train']
|
191 |
+
|
192 |
+
# pdb.set_trace()
|
193 |
+
if len(train_dev) <= int(len(dataset)*dev_rate):
|
194 |
+
train = Dataset.from_dict({"audio": [], "transcription": []})
|
195 |
+
dev = train_dev
|
196 |
+
else:
|
197 |
+
train_dev = train_dev.train_test_split(test_size=int(len(dataset)*dev_rate), seed=seed)
|
198 |
+
train = train_dev['train']
|
199 |
+
dev = train_dev['test']
|
200 |
+
return train, dev, test
|
201 |
+
|
202 |
+
# pdb.set_trace()
|
203 |
+
# P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1)
|
204 |
+
# P1tony_train_ = concatenate_datasets([P1tony_train,P1tony_scripted])
|
205 |
+
# pdb.set_trace()
|
206 |
+
|
207 |
+
negel_152_train, negel_152_dev, negel_152_test = train_dev_test_split(negel_152_dataset, dev_rate=0.1, test_rate=0.1, seed=1)
|
208 |
+
|
209 |
+
# train_dev / test
|
210 |
+
# ds = src_dataset.train_test_split(test_size=0.1, seed=1)
|
211 |
+
|
212 |
+
# dataset_libri = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
213 |
+
|
214 |
+
# train_dev = ds['train']
|
215 |
+
# # train / dev
|
216 |
+
# train_dev = train_dev.train_test_split(test_size=int(len(src_dataset)*0.1), seed=1)
|
217 |
+
# # train/dev/test
|
218 |
+
# train = train_dev['train']
|
219 |
+
# test = ds['test']
|
220 |
+
# dev = train_dev['test']
|
221 |
+
|
222 |
+
# encoded_train = train.map(prepare_dataset, num_proc=4)
|
223 |
+
# encoded_dev = dev.map(prepare_dataset, num_proc=4)
|
224 |
+
# encoded_test = test.map(prepare_dataset, num_proc=4)
|
225 |
+
# p326_encoded_train = p326_300_dataset.map(prepare_dataset, num_proc=4)
|
226 |
+
|
227 |
+
# # combine large p326 in to training set
|
228 |
+
# # encoded_train = concatenate_datasets([encoded_train, p326_encoded_train])
|
229 |
+
|
230 |
+
# encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
|
231 |
+
# encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
|
232 |
+
# encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
|
233 |
+
# encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
|
234 |
+
|
235 |
+
# encoded_P1tony_train = P1tony_train.map(prepare_dataset, num_proc=4)
|
236 |
+
# encoded_P1tony_dev = P1tony_dev.map(prepare_dataset, num_proc=4)
|
237 |
+
# encoded_P1tony_test = P1tony_test.map(prepare_dataset, num_proc=4)
|
238 |
+
|
239 |
+
# pdb.set_trace()
|
240 |
+
import numpy as np
|
241 |
+
|
242 |
+
WER = evaluate.load("wer")
|
243 |
+
|
244 |
+
## Whisper decoding
|
245 |
+
|
246 |
+
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperFeatureExtractor, Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperModel
|
247 |
+
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
|
248 |
+
# model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to("cuda:0")
|
249 |
+
model = WhisperForConditionalGeneration.from_pretrained("./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400", use_auth_token=True).to("cuda:0")
|
250 |
+
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-medium", language="English", task="transcribe")
|
251 |
+
|
252 |
+
from pathlib import Path
|
253 |
+
id = Path(fine_tuning_dir).stem
|
254 |
+
# pdb.set_trace()
|
255 |
+
tokenizer.push_to_hub("KevinGeng/%s"%id)
|
256 |
+
# import pdb
|
257 |
+
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-medium")
|
258 |
+
|
259 |
+
def whisper_prepare_dataset(batch):
|
260 |
+
# load and resample audio data from 48 to 16kHz
|
261 |
+
audio = batch["audio"]
|
262 |
+
|
263 |
+
# compute log-Mel input features from input audio array
|
264 |
+
batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
|
265 |
+
|
266 |
+
# encode target text to label ids
|
267 |
+
batch["labels"] = tokenizer(batch["transcription"]).input_ids
|
268 |
+
return batch
|
269 |
+
|
270 |
+
torch.cuda.empty_cache()
|
271 |
+
|
272 |
+
def my_map_to_pred(batch):
|
273 |
+
# pdb.set_trace()
|
274 |
+
audio = batch["audio"]
|
275 |
+
input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
|
276 |
+
# batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
|
277 |
+
batch["reference"] = processor.tokenizer._normalize(batch['transcription'])
|
278 |
+
|
279 |
+
with torch.no_grad():
|
280 |
+
# predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
|
281 |
+
predicted_ids = model.generate(input_features.to("cuda"))[0]
|
282 |
+
transcription = processor.decode(predicted_ids)
|
283 |
+
batch["prediction"] = processor.tokenizer._normalize(transcription)
|
284 |
+
return batch
|
285 |
+
|
286 |
+
import torch
|
287 |
+
|
288 |
+
from dataclasses import dataclass
|
289 |
+
from typing import Any, Dict, List, Union
|
290 |
+
|
291 |
+
@dataclass
|
292 |
+
class DataCollatorSpeechSeq2SeqWithPadding:
|
293 |
+
processor: Any
|
294 |
+
|
295 |
+
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
296 |
+
# split inputs and labels since they have to be of different lengths and need different padding methods
|
297 |
+
# first treat the audio inputs by simply returning torch tensors
|
298 |
+
input_features = [{"input_features": feature["input_features"]} for feature in features]
|
299 |
+
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
300 |
+
|
301 |
+
# get the tokenized label sequences
|
302 |
+
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
303 |
+
# pad the labels to max length
|
304 |
+
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
|
305 |
+
|
306 |
+
# replace padding with -100 to ignore loss correctly
|
307 |
+
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
308 |
+
|
309 |
+
# if bos token is appended in previous tokenization step,
|
310 |
+
# cut bos token here as it's appended later anyway
|
311 |
+
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
|
312 |
+
labels = labels[:, 1:]
|
313 |
+
|
314 |
+
batch["labels"] = labels
|
315 |
+
|
316 |
+
return batch
|
317 |
+
|
318 |
+
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
|
319 |
+
|
320 |
+
def compute_metrics(pred):
|
321 |
+
pred_ids = pred.predictions
|
322 |
+
label_ids = pred.label_ids
|
323 |
+
|
324 |
+
# replace -100 with the pad_token_id
|
325 |
+
label_ids[label_ids == -100] = tokenizer.pad_token_id
|
326 |
+
|
327 |
+
# we do not want to group tokens when computing the metrics
|
328 |
+
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
329 |
+
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
|
330 |
+
|
331 |
+
wer = 100 * WER.compute(predictions=pred_str, references=label_str)
|
332 |
+
|
333 |
+
return {"wer": wer}
|
334 |
+
|
335 |
+
# whisper_train = train.map(whisper_prepare_dataset, num_proc=4)
|
336 |
+
# pdb.set_trace()
|
337 |
+
# whisper_train_large = encoded_train.map(whisper_prepare_dataset, num_proc=4)
|
338 |
+
# whisper_dev = dev.map(whisper_prepare_dataset, num_proc=4)
|
339 |
+
# whisper_test = test.map(whisper_prepare_dataset, num_proc=4)
|
340 |
+
# pdb.set_trace()
|
341 |
+
# # Add scripted tony
|
342 |
+
# encoded_P1tony_train = P1tony_train_.map(whisper_prepare_dataset, num_proc=4)
|
343 |
+
# encoded_P1tony_dev = P1tony_dev.map(whisper_prepare_dataset, num_proc=4)
|
344 |
+
# encoded_P1tony_test = P1tony_test.map(whisper_prepare_dataset, num_proc=4)
|
345 |
+
|
346 |
+
encode_negel_152_train = negel_152_train.map(whisper_prepare_dataset, num_proc=4)
|
347 |
+
encode_negel_152_dev = negel_152_dev.map(whisper_prepare_dataset, num_proc=4)
|
348 |
+
encode_negel_152_test = negel_152_test.map(whisper_prepare_dataset, num_proc=4)
|
349 |
+
pdb.set_trace()
|
350 |
+
torch.cuda.empty_cache()
|
351 |
+
|
352 |
+
training_args = Seq2SeqTrainingArguments(
|
353 |
+
output_dir=fine_tuning_dir, # change to a repo name of your choice
|
354 |
+
per_device_train_batch_size=8,
|
355 |
+
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
|
356 |
+
learning_rate=1e-5,
|
357 |
+
warmup_steps=50,
|
358 |
+
max_steps=1000,
|
359 |
+
gradient_checkpointing=True,
|
360 |
+
fp16=True,
|
361 |
+
evaluation_strategy="steps",
|
362 |
+
save_strategy="steps",
|
363 |
+
per_device_eval_batch_size=8,
|
364 |
+
predict_with_generate=True,
|
365 |
+
generation_max_length=512,
|
366 |
+
save_steps=10,
|
367 |
+
eval_steps=10,
|
368 |
+
logging_steps=10,
|
369 |
+
report_to=["tensorboard"],
|
370 |
+
load_best_model_at_end=True,
|
371 |
+
metric_for_best_model="wer",
|
372 |
+
greater_is_better=False,
|
373 |
+
save_total_limit=5,
|
374 |
+
push_to_hub=True,
|
375 |
+
)
|
376 |
+
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
|
377 |
+
|
378 |
+
# pdb.set_trace()
|
379 |
+
# # from transformers.trainer.callbacks import TensorBoardCallback
|
380 |
+
# class EvalLoggingCallback(TrainerCallback):
|
381 |
+
# def on_evaluate(self, args, state, control, metrics, **kwargs):
|
382 |
+
# print(f"Eval loss: {metrics['eval_loss']:.4f}, Accuracy: {metrics['eval_wer']:.4f}")
|
383 |
+
|
384 |
+
# pdb.set_trace()
|
385 |
+
|
386 |
+
trainer = Seq2SeqTrainer(
|
387 |
+
args=training_args,
|
388 |
+
model=model,
|
389 |
+
train_dataset=encode_negel_152_train,
|
390 |
+
eval_dataset=encode_negel_152_dev,
|
391 |
+
data_collator=data_collator,
|
392 |
+
compute_metrics=compute_metrics,
|
393 |
+
tokenizer=processor.feature_extractor,
|
394 |
+
callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
|
395 |
+
|
396 |
+
)
|
397 |
+
# callbacks=[EvalLoggingCallback()]
|
398 |
+
trainer.train()
|
399 |
+
# trainer.evaluate(encoded_P1tony_test, metric_key_prefix="test")
|
400 |
+
# trainer.callback_handler.on_test_end(trainer, datasets=encoded_P1tony_test)
|
401 |
+
|
402 |
+
|
403 |
+
# ## Not fine tuned
|
404 |
+
# z_result = encoded_test.map(my_map_to_pred)
|
405 |
+
# # pdb.set_trace()
|
406 |
+
# # 0.4692737430167598
|
407 |
+
# z = WER.compute(references=z_result['reference'], predictions=z_result['prediction'])
|
408 |
+
|
409 |
+
# z_hel_result = encoded_healthy.map(my_map_to_pred)
|
410 |
+
# #
|
411 |
+
# z_hel = WER.compute(references=z_hel_result['reference'], predictions=z_hel_result['prediction'])
|
412 |
+
# # 0.1591610117211598
|
413 |
+
|
414 |
+
# z_fary_result = encoded_Fary.map(my_map_to_pred)
|
415 |
+
# z_far = WER.compute(references=z_fary_result['reference'], predictions=z_fary_result['prediction'])
|
416 |
+
# # 0.1791044776119403
|
417 |
+
|
418 |
+
|
419 |
+
# z_john_p326_result = encoded_John_p326.map(my_map_to_pred)
|
420 |
+
# z_john_p326 = WER.compute(references=z_john_p326_result['reference'], predictions=z_john_p326_result['prediction'])
|
421 |
+
# # 0.4648241206030151
|
422 |
+
|
423 |
+
# # y_John_video= fine_tuned_trainer.predict(encoded_John_video)
|
424 |
+
# # metrics={'test_loss': 2.665189743041992, 'test_wer': 0.7222222222222222, 'test_runtime': 0.1633, 'test_samples_per_second': 48.979, 'test_steps_per_second': 6.122})
|
425 |
+
# pdb.set_trace()
|
426 |
+
|
427 |
+
# p326 training
|
428 |
+
# metrics={'test_loss': 0.4804028868675232, 'test_wer': 0.21787709497206703, 'test_runtime': 0.3594, 'test_samples_per_second': 44.517, 'test_steps_per_second': 5.565})
|
429 |
+
# hel metrics={'test_loss': 1.6363693475723267, 'test_wer': 0.17951881554595928, 'test_runtime': 3.8451, 'test_samples_per_second': 41.611, 'test_steps_per_second': 5.201})
|
430 |
+
# Fary: metrics={'test_loss': 1.4633615016937256, 'test_wer': 0.5572139303482587, 'test_runtime': 0.6627, 'test_samples_per_second': 45.27, 'test_steps_per_second': 6.036})
|
431 |
+
# p326 large: metrics={'test_loss': 0.6568527817726135, 'test_wer': 0.2889447236180904, 'test_runtime': 0.7169, 'test_samples_per_second': 51.613, 'test_steps_per_second': 6.975})
|
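The decode script added next reloads one of these fine-tuned checkpoints for evaluation. A minimal single-utterance transcription sketch in the same spirit; the checkpoint directory is one of the paths referenced in that script, and example.wav is a placeholder file name:

import torch
import librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration

ckpt = "fine_tuned/SSD/model/Negel_152_AVA_script_conv_train_conv_dev/checkpoint-100"  # placeholder checkpoint
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
model = WhisperForConditionalGeneration.from_pretrained(ckpt).eval()

speech, sr = librosa.load("example.wav", sr=16000)  # placeholder utterance
inputs = processor(speech, sampling_rate=sr, return_tensors="pt")
with torch.no_grad():
    pred_ids = model.generate(inputs.input_features)
print(processor.batch_decode(pred_ids, skip_special_tokens=True)[0])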
local/whisper_fine_tuning_negel_decode.py
ADDED
@@ -0,0 +1,473 @@
1 |
+
fine_tuning_dir = "fine_tuned/SSD/model/Michael_100_with_Large_AVA_script_conv_train_conv_dev/checkpoint-100"
|
2 |
+
# fine_tuning_dir = "fine_tuned/SSD/model/Michael_52_with_Large_AVA_script_conv_train_conv_dev/checkpoint-60"
|
3 |
+
# fine_tuning_dir = "fine_tuned/SSD/model/Negel_152_AVA_script_conv_train_conv_dev/checkpoint-100"
|
4 |
+
# fine_tuning_dir = "fine_tuned/SSD/model/Tony1_AVA_script_conv_train_conv_dev/checkpoint-160"
|
5 |
+
# fine_tuning_dir = "fine_tuned/SSD/model/Negel_with_Large_AVA_script_conv_train_conv_dev/checkpoint-210"
|
6 |
+
|
7 |
+
"""
|
8 |
+
TODO:
|
9 |
+
+ [x] Whisper Fine-Tuned Model Evaluation
|
+ [ ]
+ [ ] Batch / Real Time support
"""
from typing import Any, Dict, List, Union
from dataclasses import dataclass
from transformers import Seq2SeqTrainer
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperFeatureExtractor, Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperModel
import evaluate
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from random import sample
from sys import flags
import gradio as gr
import torchaudio
import torch.nn as nn
import jiwer
import numpy as np
from rich import print as rprint
from rich.progress import track
from transformers import pipeline
import argparse
import yaml
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
from datasets import load_dataset, concatenate_datasets
from datasets import Dataset, Audio
import pdb
import string
import librosa
# local import
import sys

sys.path.append("src")
import lightning_module

torch.cuda.set_device("cuda:0")

audio_dir = "./data/Patient_sil_trim_16k_normed_5_snr_40"
healthy_dir = "./data/Healthy"
Fary_PAL_30 = "./data/Fary_PAL_p326_20230110_30"
John_p326 = "./data/John_p326/output"
John_video = "./data/20230103_video"
negel_79 = "./data/4_negal_152_clean_all"
negel_152 = "./data/4_negal_152_clean_all"
P1tony = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/CONVERSATIONAL/PAL"

michael3_52 = "data/3_michael_20230619_100"

patient_T = "data/Patient_T/Patient_T"
patient_L = "data/Patient_L/Patient_L"
# Get Transcription, WER and PPM
"""
TODO:
[DONE]: Automatic generating Config
"""


sys.path.append("./src")

wer = evaluate.load("wer")

# root_path = Path(__file__).parents[1]

class ChangeSampleRate(nn.Module):
    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.tensor) -> torch.tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        indices = torch.arange(new_length) * (
            self.input_rate / self.output_rate
        )
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
            0
        ) + round_up * indices.fmod(1.0).unsqueeze(0)
        return output
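# A minimal usage sketch for ChangeSampleRate (the rates below are assumed example
# values, not taken from this script): it resamples a single-channel waveform by
# linear interpolation between neighbouring samples.
#   csr = ChangeSampleRate(input_rate=44100, output_rate=16000)
#   wav_44k = torch.randn(1, 44100)   # [channels, samples]
#   wav_16k = csr(wav_44k)            # -> shape [1, 16000]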

# resample and clean text data


def dataclean(example):
    # pdb.set_trace()
    if example['audio']['sampling_rate'] != 16000:
        resampled_audio = librosa.resample(y=example['audio']['array'],
                                           orig_sr=example['audio']['sampling_rate'],
                                           target_sr=16000)
        # torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
        # resampled_audio = resampler(example['audio']['array'])

        return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
                "transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
    else:
        return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}

processor = AutoFeatureExtractor.from_pretrained(
    "facebook/wav2vec2-base-960h"
)

def prepare_dataset(batch):
    audio = batch["audio"]
    batch = processor(
        audio["array"], sampling_rate=audio["sampling_rate"], text=batch['transcription'])
    batch["input_length"] = len(batch["input_values"][0])
    return batch


negel_79_dataset = load_dataset("audiofolder", data_dir=negel_79, split="train")
negel_79_dataset = negel_79_dataset.map(dataclean)

def train_dev_test_split(dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1):
    """
    input: dataset
        dev_rate,
        test_rate
        seed
    -------
    Output:
    dataset_dict{"train", "dev", "test"}
    """
    train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
    test = train_dev_test["test"]
    train_dev = train_dev_test['train']

    # pdb.set_trace()
    if len(train_dev) <= int(len(dataset)*dev_rate):
        train = Dataset.from_dict({"audio": [], "transcription": []})
        dev = train_dev
    else:
        train_dev = train_dev.train_test_split(test_size=int(len(dataset)*dev_rate), seed=seed)
        train = train_dev['train']
        dev = train_dev['test']
    return train, dev, test
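# A minimal sketch of how train_dev_test_split is meant to be called (the argument
# values are assumed examples): with dev_rate=test_rate=0.1 roughly 80%/10%/10% of an
# audiofolder dataset ends up in train/dev/test; if the portion left after the test
# split is no larger than the requested dev size, it all becomes dev and train is
# returned empty.
#   train, dev, test = train_dev_test_split(negel_79_dataset, dev_rate=0.1, test_rate=0.1, seed=1)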
P1tony_dataset = load_dataset("audiofolder", data_dir=P1tony, split="train")
P1tony_dataset = P1tony_dataset.map(dataclean)
# pdb.set_trace()
P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1)

P3Micheal_dataset_52 = load_dataset("audiofolder", data_dir=michael3_52, split="train")
P3Micheal_dataset_52 = P3Micheal_dataset_52.map(dataclean)

Michael_52_train, Michael_52_dev, Michael_52_test = train_dev_test_split(P3Micheal_dataset_52, dev_rate=0.1, test_rate=0.1, seed=1)

# P1tony_train_ = concatenate_datasets([P1tony_train,P1tony_scripted])
# pdb.set_trace()

# Negel_79_train, Negel_79_dev, Negel_79_test = train_dev_test_split(negel_79_dataset, dev_rate=0.1, test_rate=0.1, seed=1)

src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
src_dataset = src_dataset.map(dataclean)

negel_152_dataset = load_dataset("audiofolder", data_dir=negel_152, split="train")
negel_152_dataset = negel_152_dataset.map(dataclean)

healthy_test_dataset = load_dataset(
    "audiofolder", data_dir=healthy_dir, split='train')
healthy_test_dataset = healthy_test_dataset.map(dataclean)

Fary_PAL_test_dataset = load_dataset(
    "audiofolder", data_dir=Fary_PAL_30, split='train')
Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)

John_p326_test_dataset = load_dataset(
    "audiofolder", data_dir=John_p326, split='train')
John_p326_test_dataset = John_p326_test_dataset.map(dataclean)

John_video_test_dataset = load_dataset(
    "audiofolder", data_dir=John_video, split='train')
John_video_test_dataset = John_video_test_dataset.map(dataclean)

patient_T_test_dataset = load_dataset("audiofolder", data_dir=patient_T, split='train')
patient_T_test_dataset = patient_T_test_dataset.map(dataclean)

patient_L_test_dataset = load_dataset("audiofolder", data_dir=patient_L, split='train')
patient_L_test_dataset = patient_L_test_dataset.map(dataclean)
pdb.set_trace()

# train_dev / test
ds = src_dataset.train_test_split(test_size=0.1, seed=1)

dataset_libri = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

train_dev = ds['train']
# train / dev
train_dev = train_dev.train_test_split(
    test_size=int(len(src_dataset)*0.1), seed=1)
# train/dev/test
train = train_dev['train']
test = ds['test']
dev = train_dev['test']

# # pdb.set_trace()
encoded_train = train.map(prepare_dataset, num_proc=4)
encoded_dev = dev.map(prepare_dataset, num_proc=4)
encoded_test = test.map(prepare_dataset, num_proc=4)

encoded_Tony_test = P1tony_test.map(prepare_dataset, num_proc=4)
encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)

# pdb.set_trace()

WER = evaluate.load("wer")

# Whisper decoding

processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-medium").to("cuda:0")
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-medium", language="English", task="transcribe")

# Need to push tokenizer to huggingface/model to activate online API

# tokenizer.push_to_hub("KevinGeng/whipser_medium_en_PAL300_step25")
# import pdb
# pdb.set_trace()

feature_extractor = WhisperFeatureExtractor.from_pretrained(
    "openai/whisper-medium")


def whisper_prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["transcription"]).input_ids
    return batch


torch.cuda.empty_cache()

training_args = Seq2SeqTrainingArguments(
    # change to a repo name of your choice
    output_dir="./whisper-medium-PAL128-25step",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=100,
    max_steps=1000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=512,
    save_steps=100,
    eval_steps=25,
    logging_steps=100,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)


def my_map_to_pred(batch):
    # pdb.set_trace()
    audio = batch["audio"]
    input_features = processor(
        audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
    # batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
    batch["reference"] = processor.tokenizer._normalize(batch['transcription'])

    with torch.no_grad():
        # predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
        predicted_ids = model.generate(input_features.to("cuda"))[0]
        transcription = tokenizer.decode(predicted_ids)
        batch["prediction"] = tokenizer._normalize(transcription)
    return batch


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]}
                          for feature in features]
        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]}
                          for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(
            label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)


def compute_metrics(pred):
    pdb.set_trace()
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * WER.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
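# Quick sanity check of the WER metric used above (a hand-computed example, not an
# output of this script): one substitution in a two-word reference gives WER 0.5,
# which compute_metrics would report as 50 because it scales by 100.
#   WER.compute(references=["hello world"], predictions=["hello word"])  # -> 0.5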

# encode_negel_79_train = Negel_79_train.map(whisper_prepare_dataset, num_proc=4)
# encode_negel_79_dev = Negel_79_dev.map(whisper_prepare_dataset, num_proc=4)
# encode_negel_79_test = Negel_79_test.map(whisper_prepare_dataset, num_proc=4)
whisper_test = test.map(whisper_prepare_dataset, num_proc=4)

encoded_Michael_52_train = Michael_52_train.map(whisper_prepare_dataset, num_proc=4)
encoded_Michael_52_dev = Michael_52_dev.map(whisper_prepare_dataset, num_proc=4)
encoded_Michael_52_test = Michael_52_test.map(whisper_prepare_dataset, num_proc=4)
# negel_152_train, negel_152_dev, negel_152_test = train_dev_test_split(negel_152_dataset, dev_rate=0.1, test_rate=0.1, seed=1)
# encoded_negel_152_test = negel_152_test.map(whisper_prepare_dataset, num_proc=4)
pdb.set_trace()
torch.cuda.empty_cache()

torch.cuda.empty_cache()

fine_tuned_model = WhisperForConditionalGeneration.from_pretrained(
    fine_tuning_dir
).to("cuda")
# "fine_tuned/SSD/model/whipser_medium_TEP_patient_T"
# "./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400"
# "./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-200"


def fine_tuned_map_to_pred(batch):
    # pdb.set_trace()
    audio = batch["audio"]
    input_features = processor(
        audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
    # batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
    batch["reference"] = processor.tokenizer._normalize(batch['transcription'])

    with torch.no_grad():
        # predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
        predicted_ids = fine_tuned_model.generate(input_features.to("cuda"))[0]
        transcription = tokenizer.decode(predicted_ids)
        batch["prediction"] = tokenizer._normalize(transcription)
    return batch


# output_dir="./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400",
# output_dir="fine_tuned/SSD/model/whipser_medium_TEP_patient_TL_TL",
testing_args = Seq2SeqTrainingArguments(
    # change to a repo name of your choice
    output_dir=fine_tuning_dir,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=100,
    max_steps=1000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=512,
    save_steps=100,
    eval_steps=25,
    logging_steps=100,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)

predict_trainer = Seq2SeqTrainer(
    args=testing_args,
    model=fine_tuned_model,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

# trainer.train()
# fine tuned
# z_result = encoded_test.map(fine_tuned_map_to_pred)

pdb.set_trace()
encoded_Michael_test_result = encoded_Michael_52_test.map(fine_tuned_map_to_pred)
z_M = WER.compute(references=encoded_Michael_test_result['reference'], predictions=encoded_Michael_test_result['prediction'])
pdb.set_trace()
encoded_Tony_test_result = encoded_Tony_test.map(fine_tuned_map_to_pred)
z = WER.compute(references=encoded_Tony_test_result['reference'], predictions=encoded_Tony_test_result['prediction'])
pdb.set_trace()

z_result = test.map(fine_tuned_map_to_pred)
# 0.4692737430167598
z = WER.compute(references=z_result['reference'], predictions=z_result['prediction'])
# pdb.set_trace()
z_hel_result = encoded_healthy.map(fine_tuned_map_to_pred)
z_hel = WER.compute(references=z_hel_result['reference'], predictions=z_hel_result['prediction'])
# 0.1591610117211598

# encoded_negel_152_test
# encoded_negel_test_result = encoded_negel_152_test.map(fine_tuned_map_to_pred)
# z_negel = WER.compute(references=encoded_negel_test_result['reference'], predictions=encoded_negel_test_result['prediction'])

pdb.set_trace()
z_fary_result = encoded_Fary.map(fine_tuned_map_to_pred)
z_far = WER.compute(references=z_fary_result['reference'], predictions=z_fary_result['prediction'])
# 0.1791044776119403
# z_patient_LT = encoded_patient_TL_test.map(fine_tuned_map_to_pred)
# z_patient_LT_result = WER.compute(references=z_patient_LT['reference'], predictions=z_patient_LT['prediction'])
# z_patient_L = encoded_patient_L_test.map(fine_tuned_map_to_pred)
# z_patient_L_result = WER.compute(references=z_patient_L['reference'], predictions=z_patient_L['prediction'])
# z_patient_T = encoded_patient_T_test.map(fine_tuned_map_to_pred)
# z_patient_T_result = WER.compute(references=z_patient_T['reference'], predictions=z_patient_T['prediction'])

z_john_p326_result = encoded_John_p326.map(fine_tuned_map_to_pred)
z_john_p326 = WER.compute(references=z_john_p326_result['reference'], predictions=z_john_p326_result['prediction'])

pdb.set_trace()

# # z_john_p326 = WER.compute(references=z_john_p326_result['reference'], predictions=z_john_p326_result['prediction'])
# # 0.4648241206030151
pdb.set_trace()

# # y_John_video= fine_tuned_trainer.predict(encoded_John_video)
# # metrics={'test_loss': 2.665189743041992, 'test_wer': 0.7222222222222222, 'test_runtime': 0.1633, 'test_samples_per_second': 48.979, 'test_steps_per_second': 6.122})
# pdb.set_trace()

# p326 training
# metrics={'test_loss': 0.4804028868675232, 'test_wer': 0.21787709497206703, 'test_runtime': 0.3594, 'test_samples_per_second': 44.517, 'test_steps_per_second': 5.565})
# hel metrics={'test_loss': 1.6363693475723267, 'test_wer': 0.17951881554595928, 'test_runtime': 3.8451, 'test_samples_per_second': 41.611, 'test_steps_per_second': 5.201})
# Fary: metrics={'test_loss': 1.4633615016937256, 'test_wer': 0.5572139303482587, 'test_runtime': 0.6627, 'test_samples_per_second': 45.27, 'test_steps_per_second': 6.036})
# p326 large: metrics={'test_loss': 0.6568527817726135, 'test_wer': 0.2889447236180904, 'test_runtime': 0.7169, 'test_samples_per_second': 51.613, 'test_steps_per_second': 6.975})
packages.txt
ADDED
@@ -0,0 +1,2 @@
festival
espeak # or espeak-ng on Linux
requirements.txt
ADDED
@@ -0,0 +1,103 @@
absl-py==1.0.0
aiohttp==3.8.1
aiosignal==1.2.0
analytics-python==1.4.0
antlr4-python3-runtime==4.8
anyio==3.5.0
asgiref==3.5.0
async-timeout==4.0.2
attrs==21.4.0
backoff==1.10.0
bcrypt==3.2.0
bitarray==2.4.0
cachetools==5.0.0
certifi==2021.10.8
cffi==1.15.0
charset-normalizer==2.0.12
click==8.0.4
colorama==0.4.4
cryptography==36.0.1
cycler==0.11.0
Cython==0.29.28
fairseq @ git+https://github.com/pytorch/fairseq.git@d03f4e771484a433f025f47744017c2eb6e9c6bc
fastapi==0.75.0
ffmpy==0.3.0
fonttools==4.30.0
frozenlist==1.3.0
fsspec==2022.2.0
future==0.18.2
google-auth==2.6.0
google-auth-oauthlib==0.4.6
gradio==3.40.0
grpcio==1.44.0
h11==0.12.0
hydra-core==1.0.7
idna==3.3
importlib-metadata==4.11.3
Jinja2==3.0.3
kiwisolver==1.3.2
linkify-it-py==1.0.3
Markdown==3.3.6
markdown-it-py==2.0.1
MarkupSafe==2.1.0
matplotlib==3.5.1
mdit-py-plugins==0.3.0
mdurl==0.1.0
monotonic==1.6
multidict==6.0.2
numpy==1.22.3
oauthlib==3.2.0
omegaconf==2.0.6
orjson==3.6.7
packaging==21.3
pandas==1.4.1
paramiko==2.10.1
Pillow==9.0.1
portalocker==2.4.0
protobuf==3.19.4
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pycryptodome==3.14.1
pydantic==1.9.0
pyDeprecate==0.3.1
pydub==0.25.1
PyNaCl==1.5.0
pyparsing==3.0.7
python-dateutil==2.8.2
python-multipart==0.0.5
pytorch-lightning==1.5.10
pytz==2021.3
PyYAML==6.0
regex==2022.3.2
requests==2.27.1
requests-oauthlib==1.3.1
rsa==4.8
sacrebleu==2.0.0
six==1.16.0
sniffio==1.2.0
starlette==0.17.1
tabulate==0.8.9
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
torch==1.12.1
torchaudio==0.12.1
torchmetrics==0.7.2
tqdm==4.63.0
typing-extensions==4.1.1
uc-micro-py==1.0.1
urllib3==1.26.8
uvicorn==0.17.6
Werkzeug==2.0.3
yarl==1.7.2
zipp==3.7.0

transformers
deepspeech
tensorboardX
jiwer
phonemizer
librosa

rich
src/description.html
ADDED
@@ -0,0 +1,13 @@
<p>This is the experiment page for Laronix Data Recording.<br>
<br>
1. Select one example from below; a sound file, its reference transcription, and its speaking rate will be loaded as inputs.<br>
You can check the sound file first and prepare for reading the transcription at a similar tempo.<br>
2. Delete the sound file (click the X button on the right), and a recording button will appear.<br>
3. Click the recording button to start, then click again to stop. Make sure you are not mispronouncing anything or including any detectable noise.<br>
4. Click the "Submit" button and wait for the result.<br>
5. Please check the message box for feedback; if ERROR appears, delete your previous recording and try again :).<br>
6. If the "GOOD JOB!" message appears, click "Flag as Perfect" and start another recording.<br>
7. If you try several times (N >= 10) and still cannot clear the mission, you can flag your best recording by clicking "Doubtful Speaking Rate" or "Doubtful Naturalness". <br>
Yet this seldom happens, so please try to meet the system's requirements first!<br>
8. If you have any other questions, please contact kevin@laronix.com </p>
<img src="https://static.wixstatic.com/media/e7e144_93e98148d06147828031797eb4525b80~mv2.png/v1/crop/x_0,y_25,w_2606,h_882/fill/w_396,h_142,al_c,q_85,usm_0.66_1.00_0.01,enc_auto/newlogo.png" align="right" height="20%" width="20%">
src/lightning_module.py
ADDED
@@ -0,0 +1,43 @@
import sys
sys.path.append("src")
import pytorch_lightning as pl
import torch
import torch.nn as nn
import os
import numpy as np
import hydra
from model import load_ssl_model, PhonemeEncoder, DomainEmbedding, LDConditioner, Projection


class BaselineLightningModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.construct_model()
        self.save_hyperparameters()

    def construct_model(self):
        self.feature_extractors = nn.ModuleList([
            load_ssl_model(cp_path='src/wav2vec_small.pt'),
            DomainEmbedding(3,128),
        ])
        output_dim = sum([ feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors])
        output_layers = [
            LDConditioner(judge_dim=128,num_judges=3000,input_dim=output_dim)
        ]
        output_dim = output_layers[-1].get_output_dim()
        output_layers.append(
            Projection(hidden_dim=2048,activation=torch.nn.ReLU(),range_clipping=False,input_dim=output_dim)

        )

        self.output_layers = nn.ModuleList(output_layers)

    def forward(self, inputs):
        outputs = {}
        for feature_extractor in self.feature_extractors:
            outputs.update(feature_extractor(inputs))
        x = outputs
        for output_layer in self.output_layers:
            x = output_layer(x,inputs)
        return x
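# A minimal sketch of restoring this module from a trained checkpoint via PyTorch
# Lightning (the checkpoint path is an assumed placeholder, not a file shipped in
# this commit):
#   model = BaselineLightningModule.load_from_checkpoint(
#       "path/to/mos_checkpoint.ckpt", map_location="cpu").eval()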
src/model.py
ADDED
@@ -0,0 +1,191 @@
import torch
import torch.nn as nn
import fairseq
import os
import hydra

def load_ssl_model(cp_path):
    ssl_model_type = cp_path.split("/")[-1]
    wavlm = "WavLM" in ssl_model_type
    if wavlm:
        checkpoint = torch.load(cp_path)
        cfg = WavLMConfig(checkpoint['cfg'])
        ssl_model = WavLM(cfg)
        ssl_model.load_state_dict(checkpoint['model'])
        if 'Large' in ssl_model_type:
            SSL_OUT_DIM = 1024
        else:
            SSL_OUT_DIM = 768
    else:
        if ssl_model_type == "wav2vec_small.pt":
            SSL_OUT_DIM = 768
        elif ssl_model_type in ["w2v_large_lv_fsh_swbd_cv.pt", "xlsr_53_56k.pt"]:
            SSL_OUT_DIM = 1024
        else:
            print("*** ERROR *** SSL model type " + ssl_model_type + " not supported.")
            exit()
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [cp_path]
        )
        ssl_model = model[0]
        ssl_model.remove_pretraining_modules()
    return SSL_model(ssl_model, SSL_OUT_DIM, wavlm)

class SSL_model(nn.Module):
    def __init__(self,ssl_model,ssl_out_dim,wavlm) -> None:
        super(SSL_model,self).__init__()
        self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
        self.WavLM = wavlm

    def forward(self,batch):
        wav = batch['wav']
        wav = wav.squeeze(1)  # [batches, audio_len]
        if self.WavLM:
            x = self.ssl_model.extract_features(wav)[0]
        else:
            res = self.ssl_model(wav, mask=False, features_only=True)
            x = res["x"]
        return {"ssl-feature":x}
    def get_output_dim(self):
        return self.ssl_out_dim


class PhonemeEncoder(nn.Module):
    '''
    PhonemeEncoder consists of an embedding layer, an LSTM layer, and a linear layer.
    Args:
        vocab_size: the size of the vocabulary
        hidden_dim: the size of the hidden state of the LSTM
        emb_dim: the size of the embedding layer
        out_dim: the size of the output of the linear layer
        n_lstm_layers: the number of LSTM layers
    '''
    def __init__(self, vocab_size, hidden_dim, emb_dim, out_dim,n_lstm_layers,with_reference=True) -> None:
        super().__init__()
        self.with_reference = with_reference
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim,
                               num_layers=n_lstm_layers, dropout=0.1, bidirectional=True)
        self.linear = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim*self.with_reference, out_dim),
            nn.ReLU()
        )
        self.out_dim = out_dim

    def forward(self,batch):
        seq = batch['phonemes']
        lens = batch['phoneme_lens']
        reference_seq = batch['reference']
        reference_lens = batch['reference_lens']
        emb = self.embedding(seq)
        emb = torch.nn.utils.rnn.pack_padded_sequence(
            emb, lens, batch_first=True, enforce_sorted=False)
        _, (ht, _) = self.encoder(emb)
        feature = ht[-1] + ht[0]
        if self.with_reference:
            if reference_seq==None or reference_lens ==None:
                raise ValueError("reference_batch and reference_lens should not be None when with_reference is True")
            reference_emb = self.embedding(reference_seq)
            reference_emb = torch.nn.utils.rnn.pack_padded_sequence(
                reference_emb, reference_lens, batch_first=True, enforce_sorted=False)
            _, (ht_ref, _) = self.encoder(emb)
            reference_feature = ht_ref[-1] + ht_ref[0]
            feature = self.linear(torch.cat([feature,reference_feature],1))
        else:
            feature = self.linear(feature)
        return {"phoneme-feature": feature}
    def get_output_dim(self):
        return self.out_dim
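# A minimal instantiation sketch for PhonemeEncoder (all sizes are assumed example
# values, not the ones used by any shipped checkpoint); the input batch must carry
# padded phoneme id sequences with their lengths, plus reference sequences when
# with_reference=True.
#   enc = PhonemeEncoder(vocab_size=80, hidden_dim=256, emb_dim=256, out_dim=256,
#                        n_lstm_layers=3, with_reference=True)
#   out = enc({"phonemes": phoneme_ids, "phoneme_lens": lens,
#              "reference": ref_ids, "reference_lens": ref_lens})["phoneme-feature"]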

class DomainEmbedding(nn.Module):
    def __init__(self,n_domains,domain_dim) -> None:
        super().__init__()
        self.embedding = nn.Embedding(n_domains,domain_dim)
        self.output_dim = domain_dim
    def forward(self, batch):
        return {"domain-feature": self.embedding(batch['domains'])}
    def get_output_dim(self):
        return self.output_dim


class LDConditioner(nn.Module):
    '''
    Conditions ssl output by listener embedding
    '''
    def __init__(self,input_dim, judge_dim, num_judges=None):
        super().__init__()
        self.input_dim = input_dim
        self.judge_dim = judge_dim
        self.num_judges = num_judges
        assert num_judges !=None
        self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
        # concat [self.output_layer, phoneme features]

        self.decoder_rnn = nn.LSTM(
            input_size = self.input_dim + self.judge_dim,
            hidden_size = 512,
            num_layers = 1,
            batch_first = True,
            bidirectional = True
        )  # linear?
        self.out_dim = self.decoder_rnn.hidden_size*2

    def get_output_dim(self):
        return self.out_dim


    def forward(self, x, batch):
        judge_ids = batch['judge_id']
        if 'phoneme-feature' in x.keys():
            concatenated_feature = torch.cat((x['ssl-feature'], x['phoneme-feature'].unsqueeze(1).expand(-1,x['ssl-feature'].size(1) ,-1)),dim=2)
        else:
            concatenated_feature = x['ssl-feature']
        if 'domain-feature' in x.keys():
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    x['domain-feature']
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        if judge_ids != None:
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    self.judge_embedding(judge_ids)
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
        return decoder_output

class Projection(nn.Module):
    def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
        super(Projection, self).__init__()
        self.range_clipping = range_clipping
        output_dim = 1
        if range_clipping:
            self.proj = nn.Tanh()

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )
        self.output_dim = output_dim

    def forward(self, x, batch):
        output = self.net(x)

        # range clipping
        if self.range_clipping:
            return self.proj(output) * 2.0 + 3
        else:
            return output
    def get_output_dim(self):
        return self.output_dim
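# Note on Projection's range clipping: with range_clipping=True the output is mapped to
# 2*tanh(x) + 3, i.e. squeezed into the open interval (1, 5), presumably to match a
# 1-5 MOS scale; the configuration built in lightning_module.py uses
# range_clipping=False, so raw regression outputs are returned.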