KevinGeng committed on
Commit
a1fe393
1 Parent(s): b67bc9d

push to HF

.gitattributes CHANGED
@@ -2,13 +2,11 @@
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
 *.npy filter=lfs diff=lfs merge=lfs -text
@@ -16,16 +14,14 @@
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +29,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ src/epoch=3-step=7459.ckpt filter=lfs diff=lfs merge=lfs -text
+ src/wav2vec_small.pt filter=lfs diff=lfs merge=lfs -text
+ data/p326_split filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,541 @@
1
+ # generated by: https://github.com/michaelliao/gitignore-online-generator
2
+
3
+ #################### Python.gitignore ####################
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
165
+
166
+ #################### Archives.gitignore ####################
167
+
168
+ # It's better to unpack these files and commit the raw source because
169
+ # git has its own built in compression methods.
170
+ *.7z
171
+ *.jar
172
+ *.rar
173
+ *.zip
174
+ *.gz
175
+ *.gzip
176
+ *.tgz
177
+ *.bzip
178
+ *.bzip2
179
+ *.bz2
180
+ *.xz
181
+ *.lzma
182
+ *.cab
183
+ *.xar
184
+
185
+ # Packing-only formats
186
+ *.iso
187
+ *.tar
188
+
189
+ # Package management formats
190
+ *.dmg
191
+ *.xpi
192
+ *.gem
193
+ *.egg
194
+ *.deb
195
+ *.rpm
196
+ *.msi
197
+ *.msm
198
+ *.msp
199
+ *.txz
200
+
201
+ #################### Backup.gitignore ####################
202
+
203
+ *.bak
204
+ *.gho
205
+ *.ori
206
+ *.orig
207
+ *.tmp
208
+
209
+ #################### Emacs.gitignore ####################
210
+
211
+ # -*- mode: gitignore; -*-
212
+ *~
213
+ \#*\#
214
+ /.emacs.desktop
215
+ /.emacs.desktop.lock
216
+ *.elc
217
+ auto-save-list
218
+ tramp
219
+ .\#*
220
+
221
+ # Org-mode
222
+ .org-id-locations
223
+ *_archive
224
+
225
+ # flymake-mode
226
+ *_flymake.*
227
+
228
+ # eshell files
229
+ /eshell/history
230
+ /eshell/lastdir
231
+
232
+ # elpa packages
233
+ /elpa/
234
+
235
+ # reftex files
236
+ *.rel
237
+
238
+ # AUCTeX auto folder
239
+ /auto/
240
+
241
+ # cask packages
242
+ .cask/
243
+ dist/
244
+
245
+ # Flycheck
246
+ flycheck_*.el
247
+
248
+ # server auth directory
249
+ /server/
250
+
251
+ # projectiles files
252
+ .projectile
253
+
254
+ # directory configuration
255
+ .dir-locals.el
256
+
257
+ # network security
258
+ /network-security.data
259
+
260
+
261
+ #################### JetBrains.gitignore ####################
262
+
263
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
264
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
265
+
266
+ # User-specific stuff
267
+ .idea/**/workspace.xml
268
+ .idea/**/tasks.xml
269
+ .idea/**/usage.statistics.xml
270
+ .idea/**/dictionaries
271
+ .idea/**/shelf
272
+
273
+ # AWS User-specific
274
+ .idea/**/aws.xml
275
+
276
+ # Generated files
277
+ .idea/**/contentModel.xml
278
+
279
+ # Sensitive or high-churn files
280
+ .idea/**/dataSources/
281
+ .idea/**/dataSources.ids
282
+ .idea/**/dataSources.local.xml
283
+ .idea/**/sqlDataSources.xml
284
+ .idea/**/dynamic.xml
285
+ .idea/**/uiDesigner.xml
286
+ .idea/**/dbnavigator.xml
287
+
288
+ # Gradle
289
+ .idea/**/gradle.xml
290
+ .idea/**/libraries
291
+
292
+ # Gradle and Maven with auto-import
293
+ # When using Gradle or Maven with auto-import, you should exclude module files,
294
+ # since they will be recreated, and may cause churn. Uncomment if using
295
+ # auto-import.
296
+ # .idea/artifacts
297
+ # .idea/compiler.xml
298
+ # .idea/jarRepositories.xml
299
+ # .idea/modules.xml
300
+ # .idea/*.iml
301
+ # .idea/modules
302
+ # *.iml
303
+ # *.ipr
304
+
305
+ # CMake
306
+ cmake-build-*/
307
+
308
+ # Mongo Explorer plugin
309
+ .idea/**/mongoSettings.xml
310
+
311
+ # File-based project format
312
+ *.iws
313
+
314
+ # IntelliJ
315
+ out/
316
+
317
+ # mpeltonen/sbt-idea plugin
318
+ .idea_modules/
319
+
320
+ # JIRA plugin
321
+ atlassian-ide-plugin.xml
322
+
323
+ # Cursive Clojure plugin
324
+ .idea/replstate.xml
325
+
326
+ # SonarLint plugin
327
+ .idea/sonarlint/
328
+
329
+ # Crashlytics plugin (for Android Studio and IntelliJ)
330
+ com_crashlytics_export_strings.xml
331
+ crashlytics.properties
332
+ crashlytics-build.properties
333
+ fabric.properties
334
+
335
+ # Editor-based Rest Client
336
+ .idea/httpRequests
337
+
338
+ # Android studio 3.1+ serialized cache file
339
+ .idea/caches/build_file_checksums.ser
340
+
341
+ #################### Linux.gitignore ####################
342
+
343
+ *~
344
+
345
+ # temporary files which can be created if a process still has a handle open of a deleted file
346
+ .fuse_hidden*
347
+
348
+ # KDE directory preferences
349
+ .directory
350
+
351
+ # Linux trash folder which might appear on any partition or disk
352
+ .Trash-*
353
+
354
+ # .nfs files are created when an open file is removed but is still being accessed
355
+ .nfs*
356
+
357
+ #################### NotepadPP.gitignore ####################
358
+
359
+ # Notepad++ backups #
360
+ *.bak
361
+
362
+ #################### PuTTY.gitignore ####################
363
+
364
+ # Private key
365
+ *.ppk
366
+
367
+ #################### SublimeText.gitignore ####################
368
+
369
+ # Cache files for Sublime Text
370
+ *.tmlanguage.cache
371
+ *.tmPreferences.cache
372
+ *.stTheme.cache
373
+
374
+ # Workspace files are user-specific
375
+ *.sublime-workspace
376
+
377
+ # Project files should be checked into the repository, unless a significant
378
+ # proportion of contributors will probably not be using Sublime Text
379
+ # *.sublime-project
380
+
381
+ # SFTP configuration file
382
+ sftp-config.json
383
+ sftp-config-alt*.json
384
+
385
+ # Package control specific files
386
+ Package Control.last-run
387
+ Package Control.ca-list
388
+ Package Control.ca-bundle
389
+ Package Control.system-ca-bundle
390
+ Package Control.cache/
391
+ Package Control.ca-certs/
392
+ Package Control.merged-ca-bundle
393
+ Package Control.user-ca-bundle
394
+ oscrypto-ca-bundle.crt
395
+ bh_unicode_properties.cache
396
+
397
+ # Sublime-github package stores a github token in this file
398
+ # https://packagecontrol.io/packages/sublime-github
399
+ GitHub.sublime-settings
400
+
401
+ #################### Vim.gitignore ####################
402
+
403
+ # Swap
404
+ [._]*.s[a-v][a-z]
405
+ !*.svg # comment out if you don't need vector files
406
+ [._]*.sw[a-p]
407
+ [._]s[a-rt-v][a-z]
408
+ [._]ss[a-gi-z]
409
+ [._]sw[a-p]
410
+
411
+ # Session
412
+ Session.vim
413
+ Sessionx.vim
414
+
415
+ # Temporary
416
+ .netrwhist
417
+ *~
418
+ # Auto-generated tag files
419
+ tags
420
+ # Persistent undo
421
+ [._]*.un~
422
+
423
+ #################### VirtualEnv.gitignore ####################
424
+
425
+ # Virtualenv
426
+ # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
427
+ .Python
428
+ [Bb]in
429
+ [Ii]nclude
430
+ [Ll]ib
431
+ [Ll]ib64
432
+ [Ll]ocal
433
+ [Ss]cripts
434
+ pyvenv.cfg
435
+ .venv
436
+ pip-selfcheck.json
437
+
438
+ #################### VisualStudioCode.gitignore ####################
439
+
440
+ .vscode/*
441
+ !.vscode/settings.json
442
+ !.vscode/tasks.json
443
+ !.vscode/launch.json
444
+ !.vscode/extensions.json
445
+ !.vscode/*.code-snippets
446
+
447
+ # Local History for Visual Studio Code
448
+ .history/
449
+
450
+ # Built Visual Studio Code Extensions
451
+ *.vsix
452
+
453
+ #################### Windows.gitignore ####################
454
+
455
+ # Windows thumbnail cache files
456
+ Thumbs.db
457
+ Thumbs.db:encryptable
458
+ ehthumbs.db
459
+ ehthumbs_vista.db
460
+
461
+ # Dump file
462
+ *.stackdump
463
+
464
+ # Folder config file
465
+ [Dd]esktop.ini
466
+
467
+ # Recycle Bin used on file shares
468
+ $RECYCLE.BIN/
469
+
470
+ # Windows Installer files
471
+ *.cab
472
+ *.msi
473
+ *.msix
474
+ *.msm
475
+ *.msp
476
+
477
+ # Windows shortcuts
478
+ *.lnk
479
+
480
+ #################### macOS.gitignore ####################
481
+
482
+ # General
483
+ .DS_Store
484
+ .AppleDouble
485
+ .LSOverride
486
+
487
+ # Icon must end with two \r
488
+ Icon
489
+
490
+
491
+ # Thumbnails
492
+ ._*
493
+
494
+ # Files that might appear in the root of a volume
495
+ .DocumentRevisions-V100
496
+ .fseventsd
497
+ .Spotlight-V100
498
+ .TemporaryItems
499
+ .Trashes
500
+ .VolumeIcon.icns
501
+ .com.apple.timemachine.donotpresent
502
+
503
+ # Directories potentially created on remote AFP share
504
+ .AppleDB
505
+ .AppleDesktop
506
+ Network Trash Folder
507
+ Temporary Items
508
+ .apdisk
509
+
510
+ #################### Custom.gitignore ####################
511
+
512
+ # add your custom gitignore here:
513
+ !.gitignore
514
+ !.gitsubmodules
515
+
516
+ # ignore data
517
+ data/
518
+ exp/
519
+ !src/lightning_module.py
520
+ *.wav
521
+ # ignore plots
522
+ *.png
523
+ # ignore csv
524
+ *.csv
525
+
526
+
527
+ !config/template.yaml
528
+ config
529
+ !local
530
+
531
+ ## Currently
532
+ src/wav2vec_small.pt
533
+ *.ckpt
534
+ *.bak
535
+
536
+ fine_tuned/
537
+
538
+ #
539
+ user/
540
+
541
+ .vscode
README.md CHANGED
@@ -1,12 +1,127 @@
+ # Laronix Data Collection
+
+ This repository contains information about the Laronix data collection process, which involves collecting parallel data from AVA users. The dataset consists of two main sessions: scripted data and conversational data.
+
+ ## Dataset
+
+ The dataset is organized as follows:
+
+ ### 1. Scripted Data
+
+ The scripted data session includes 200 sentences collected from 5 articles. The audio and text references for these sentences have already been uploaded, or will be uploaded, to the Laronix Recording system. (Ask [Kevin](mailto:kevin@laronix.com) for these files.) The distribution of sentences across the articles is as follows:
+
+ - Arthur the Rat: 56 sentences
+ - Cinder: 19 sentences
+ - Rainbow: 26 sentences
+ - Sentences: 59 sentences
+ - VCTK: 40 sentences
+
+ ### 2. Conversational Data
+
+ The conversational data session focuses on natural conversation and involves the following components:
+
+ #### a. Q&A
+
+ In this component, a set of 50 sentences consisting of questions and answers is provided. During the recording, the partner asks the questions (Q) and the patient provides the answers (A). Both the questions and the answers are recorded.
+
+ #### b. Freestyle
+
+ Patients are free to talk about a given topic and are asked to respond with 5 to 10 sentences. The structure for this component can be referenced from the [IELTS speaking test](https://www.ieltsbuddy.com/IELTS-speaking-questions-with-answers.html).
+
+ ## Data Inclusion Criteria
+
+ + No hearing loss or history of active cancer.
+ + 6 weeks of practice with AVA.
+
+ ## Document for Laronix Recording System
+
+ The Laronix recording system is designed for data collection from potential users of the AVA Device, which replaces their vocal cords.
+
+ ### Input:
+
+ - Audio signal
+ - Reference ID
+ - Reference text
+ - Reference phonemes per minute (PPM)
+
+ ### Output:
+
+ - wav_pause_plot: Wave-signal plot with pauses detected by the VAD algorithm (SNR = 40 dB)
+ - Predicted Mean Opinion Score: Estimated data quality on the MOS scale (1-5) from an ML prediction model
+ - Hypotheses: Text predicted by the Automatic Speech Recognition model (wav2vec 2.0 + CTC)
+ - WER: Word Error Rate (lower is better)
+ - Predicted Phonemes
+ - PPM: Phonemes per minute
+ - Message: Feedback from the system
+
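The `WER` and `PPM` outputs above are computed by `app.py`, added later in this commit, using `jiwer` and a wav2vec 2.0 phoneme recognizer. The snippet below is a minimal sketch of that computation, simplified from the app code (it omits the VAD trimming that `app.py` applies before measuring speech duration):

```
import jiwer
import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Text normalization applied to both reference and hypothesis before WER (as in app.py)
transformation = jiwer.Compose([
    jiwer.RemovePunctuation(),
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
])

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")


def wer_and_ppm(wav_path: str, reference: str, hypothesis: str):
    # Word error rate between the reference text and the ASR hypothesis
    wer = jiwer.wer(
        reference,
        hypothesis,
        truth_transform=transformation,
        hypothesis_transform=transformation,
    )

    # Phonemes per minute: CTC-decode phonemes, then divide by the audio duration
    wav, sr = librosa.load(wav_path, sr=16000)
    inputs = processor(wav, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        logits = phoneme_model(inputs.input_values).logits
    phones = processor.batch_decode(torch.argmax(logits, dim=-1))[0].split(" ")
    ppm = len(phones) / (len(wav) / sr) * 60
    return wer, ppm
```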
+ ## User Instruction
+
+ Please follow the instructions provided at the top of the APP page. The expected directory layout is:
+
+ ```
+ - Laronix_AUTOMOS
+   - data
+     - Template
+       - ref_wav/
+         - 1.wav
+         - 2.wav
+         - ...
+       - ref_txt.txt
+       - ref.csv                 # audio prosody feature reference <generated by script>
+   - exp
+     - Template
+       - Audio_to_evaluate       # RAW WAV DATA
+       - log.csv                 # recording log
+       - output                  # wav files <generated by script>
+   - model
+     - epoch=3-step=7459.ckpt    # MOS estimation model
+     - wav2vec_small.pt          # WER model
+   - local
+     - get_ref_PPM.py            # script for generating data/<ref_dir>/ref.csv
+     - post_processing.py        # script for generating exp/<ref_dir>/output/*.wav
+ ```
+
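The checkpoints under `model/` and the `ref.csv` file above are consumed by `app.py`; the sketch below shows how they are loaded there. The `data/Template/ref.csv` path is illustrative (the real path comes from the config), and in this Space the checkpoints are actually stored under `src/`:

```
import sys

import numpy as np

sys.path.append("src")
import lightning_module  # src/lightning_module.py, kept in the repo

# MOS estimation model (epoch=3-step=7459.ckpt)
mos_model = lightning_module.BaselineLightningModule.load_from_checkpoint(
    "src/epoch=3-step=7459.ckpt"
).eval()

# ref.csv: one row per reference utterance; the last column holds the reference PPM
ref_feature = np.loadtxt("data/Template/ref.csv", delimiter=",", dtype="str")
refs_ppm = ref_feature[:, -1][1:]  # skip the header row
```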
---
- title: Laronix Recording
- emoji: 🐠
- colorFrom: gray
+ title: Laronix Automos
+ emoji: 🏃
+ colorFrom: blue
colorTo: blue
sdk: gradio
- sdk_version: 3.47.1
+ sdk_version: 3.2
app_file: app.py
pinned: false
+ license: afl-3.0
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # Laronix_AutoMOS
+
+ ## Usage:
+ ### Step 1: Prepare data and text
+ `<todo>`
+ ### Step 2: Preprocessing
+ ```
+ ## Generate *.csv, the Voice/Unvoice plot (optional), and the config (optional)
+ python local/get_ref_PPM.py --ref_txt <ref_text> \
+     --ref_wavs <ref_wavs> \
+     --output_dir <output_dir> \
+     --to_config <True/False> \
+     --UV_flag <True/False> \
+     --UV_thre <UV_thre>
+ ```
+ ### Step 3: Launch a recording session
+
+ ```
+ ## Start app.py
+ python app.py <config.yaml>
+ ```
+ + **Find the log lines below and click the URL to start**
+ ```
+ Launch examples
+ Running on local URL: http://127.0.0.1:7860/
+ ...
+ (Logs...)
+ ...
+ Running on public URL: https://87abe771e93229da.gradio.app
+ ```
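For Step 3, `app.py` reads the reference paths and thresholds from the YAML config passed on the command line (`config["ref_txt"]`, `config["ref_feature"]`, `config["ref_wavs"]`, `config["thre"]`, `config["exp_id"]`). A minimal sketch of writing such a config follows; every path and threshold value here is a placeholder, not the project's real setting:

```
import os

import yaml

# Key names mirror what app.py accesses; values below are placeholders only.
example_config = {
    "ref_txt": "data/Template/ref_txt.txt",   # "<id> <reference text>" per line
    "ref_feature": "data/Template/ref.csv",   # output of local/get_ref_PPM.py
    "ref_wavs": "data/Template/ref_wav",      # directory of reference wavs
    "exp_id": None,                           # app.py falls back to the config file's stem
    "thre": {
        "maxppm": 80,     # allowed PPM above the reference
        "minppm": 80,     # allowed PPM below the reference
        "AUTOMOS": 3.0,   # minimum predicted MOS
        "WER": 0.3,       # maximum word error rate
    },
}

os.makedirs("config", exist_ok=True)
with open("config/example.yaml", "w") as f:
    yaml.safe_dump(example_config, f)
```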
app.py ADDED
@@ -0,0 +1,468 @@
1
+ """
2
+ TODO:
3
+ + [x] Load Configuration
4
+ + [ ] Checking
5
+ + [ ] Better saving directory
6
+ """
7
+ import numpy as np
8
+ from pathlib import Path
9
+ import jiwer
10
+ import pdb
11
+ import torch.nn as nn
12
+ import torch
13
+ import torchaudio
14
+ import gradio as gr
15
+ from logging import PlaceHolder
16
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
17
+ import yaml
18
+ from transformers import pipeline
19
+ import librosa
20
+ import librosa.display
21
+ import matplotlib.pyplot as plt
22
+
23
+
24
+ # local import
25
+ import sys
26
+
27
+ sys.path.append("src")
28
+ import lightning_module
29
+
30
+ # Load automos
31
+ config_yaml = sys.argv[1]
32
+ with open(config_yaml, "r") as f:
33
+ # pdb.set_trace()
34
+ try:
35
+ config = yaml.safe_load(f)
36
+ except yaml.YAMLError:
37
+ print("Config file Loading Error")
38
+ exit()
39
+
40
+ # Auto load examples
41
+
42
+ with open(config["ref_txt"], "r") as f:
43
+ refs = f.readlines()
44
+ refs_ids = [x.split()[0] for x in refs]
45
+ refs_txt = [" ".join(x.split()[1:]) for x in refs]
46
+ ref_feature = np.loadtxt(config["ref_feature"], delimiter=",", dtype="str")
47
+ ref_wavs = [str(x) for x in sorted(Path(config["ref_wavs"]).glob("**/*.wav"))]
48
+
49
+ dummy_wavs = [None for x in np.arange(len(ref_wavs))]
50
+
51
+ refs_ppm = np.array(ref_feature[:, -1][1:], dtype="str")
52
+
53
+ reference_id = gr.Textbox(value="ID", placeholder="Utter ID", label="Reference_ID")
54
+
55
+ reference_textbox = gr.Textbox(
56
+ value="Input reference here",
57
+ placeholder="Input reference here",
58
+ label="Reference",
59
+ )
60
+ reference_PPM = gr.Textbox(placeholder="Pneumatic Voice's PPM", label="Ref PPM")
61
+
62
+ # Set up interface
63
+ # remove dummy wavs, use the same ref_wavs for eval wavs
64
+ print("Preparing Examples")
65
+ examples = [
66
+ [w, w_, i, x, y] for w, w_, i, x, y in zip(ref_wavs, ref_wavs, refs_ids, refs_txt, refs_ppm)
67
+ ]
68
+
69
+ p = pipeline(
70
+ "automatic-speech-recognition",
71
+ model="KevinGeng/whipser_medium_en_PAL300_step25",
72
+ device=0,
73
+ )
74
+
75
+ # WER part
76
+ transformation = jiwer.Compose(
77
+ [
78
+ jiwer.RemovePunctuation(),
79
+ jiwer.ToLowerCase(),
80
+ jiwer.RemoveWhiteSpace(replace_by_space=True),
81
+ jiwer.RemoveMultipleSpaces(),
82
+ jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
83
+ ]
84
+ )
85
+
86
+ # WPM part
87
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
88
+ phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
89
+
90
+
91
+ class ChangeSampleRate(nn.Module):
92
+ def __init__(self, input_rate: int, output_rate: int):
93
+ super().__init__()
94
+ self.output_rate = output_rate
95
+ self.input_rate = input_rate
96
+
97
+ def forward(self, wav: torch.tensor) -> torch.tensor:
98
+ # Only accepts 1-channel waveform input
99
+ wav = wav.view(wav.size(0), -1)
100
+ new_length = wav.size(-1) * self.output_rate // self.input_rate
101
+ indices = torch.arange(new_length) * (self.input_rate / self.output_rate)
102
+ round_down = wav[:, indices.long()]
103
+ round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
104
+ output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(0) + (
105
+ round_up * indices.fmod(1.0).unsqueeze(0)
106
+ )
107
+ return output
108
+
109
+
110
+ # MOS model
111
+ model = lightning_module.BaselineLightningModule.load_from_checkpoint(
112
+ "src/epoch=3-step=7459.ckpt"
113
+ ).eval()
114
+
115
+ # Get Speech Interval
116
+
117
+ def get_speech_interval(signal, db):
118
+ audio_interv = librosa.effects.split(signal, top_db=db)
119
+ pause_end = [x[0] for x in audio_interv[1:]]
120
+ pause_start = [x[1] for x in audio_interv[0:-1]]
121
+ pause_interv = [[x, y] for x, y in zip(pause_start, pause_end)]
122
+ return audio_interv, pause_interv
123
+
124
+ # plot UV
125
+
126
+
127
+ def plot_UV(signal, audio_interv, sr):
128
+ fig, ax = plt.subplots(nrows=2, sharex=True)
129
+ librosa.display.waveshow(signal, sr=sr, ax=ax[0])
130
+ uv_flag = np.zeros(len(signal))
131
+ for i in audio_interv:
132
+ uv_flag[i[0] : i[1]] = 1
133
+
134
+ ax[1].plot(np.arange(len(signal)) / sr, uv_flag, "r")
135
+ ax[1].set_ylim([-0.1, 1.1])
136
+ return fig
137
+
138
+ def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
139
+ if audio_path == None:
140
+ audio_path = _
141
+ print("using ref audio as eval audio since it's empty")
142
+
143
+ wav, sr = torchaudio.load(audio_path)
144
+ if wav.shape[0] != 1:
145
+ wav = wav[0, :]
146
+ print(wav.shape)
147
+
148
+ osr = 16000
149
+ batch = wav.unsqueeze(0).repeat(10, 1, 1)
150
+ csr = ChangeSampleRate(sr, osr)
151
+ out_wavs = csr(wav)
152
+
153
+ # ASR
154
+ trans = jiwer.ToLowerCase()(p(audio_path)["text"])
155
+
156
+ # WER
157
+ wer = jiwer.wer(
158
+ ref,
159
+ trans,
160
+ truth_transform=transformation,
161
+ hypothesis_transform=transformation,
162
+ )
163
+ # MOS
164
+ batch = {
165
+ "wav": out_wavs,
166
+ "domains": torch.tensor([0]),
167
+ "judge_id": torch.tensor([288]),
168
+ }
169
+ with torch.no_grad():
170
+ output = model(batch)
171
+ predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
172
+
173
+ # Phonemes per minute (PPM)
174
+ with torch.no_grad():
175
+ logits = phoneme_model(out_wavs).logits
176
+ phone_predicted_ids = torch.argmax(logits, dim=-1)
177
+ phone_transcription = processor.batch_decode(phone_predicted_ids)
178
+ lst_phonemes = phone_transcription[0].split(" ")
179
+
180
+ # VAD for pause detection
181
+ wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
182
+ # pdb.set_trace()
183
+ a_h, p_h = get_speech_interval(wav_vad.numpy(), db=40)
184
+ # print(a_h)
185
+ # print(len(a_h))
186
+ fig_h = plot_UV(wav_vad.numpy().squeeze(), a_h, sr=sr)
187
+ ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
188
+
189
+ error_msg = "!!! ERROR MESSAGE !!!\n"
190
+ if audio_path == _ or audio_path == None:
191
+ error_msg += "ERROR: Fail recording, Please start from the beginning again."
192
+ return (
193
+ fig_h,
194
+ predic_mos,
195
+ trans,
196
+ wer,
197
+ phone_transcription,
198
+ ppm,
199
+ error_msg,
200
+ )
201
+ if ppm >= float(pre_ppm) + float(config["thre"]["maxppm"]):
202
+ error_msg += "ERROR: Please speak slower.\n"
203
+ elif ppm <= float(pre_ppm) - float(config["thre"]["minppm"]):
204
+ error_msg += "ERROR: Please speak faster.\n"
205
+ elif predic_mos <= float(config["thre"]["AUTOMOS"]):
206
+ error_msg += "ERROR: Naturalness is too low, Please try again.\n"
207
+ elif wer >= float(config["thre"]["WER"]):
208
+ error_msg += "ERROR: Intelligibility is too low, Please try again\n"
209
+ else:
210
+ error_msg = (
211
+ "GOOD JOB! Please 【Save the Recording】.\nYou can start recording the next sample."
212
+ )
213
+
214
+ return (
215
+ fig_h,
216
+ predic_mos,
217
+ trans,
218
+ wer,
219
+ phone_transcription,
220
+ ppm,
221
+ error_msg,
222
+ )
223
+
224
+ with open("src/description.html", "r", encoding="utf-8") as f:
225
+ description = f.read()
226
+ # description
227
+
228
+ refs_ppm = np.array(ref_feature[:, -1][1:], dtype="str")
229
+
230
+ reference_id = gr.Textbox(value="ID", placeholder="Utter ID", label="Reference_ID", visible=False)
231
+ reference_textbox = gr.Textbox(
232
+ value="Input reference here",
233
+ placeholder="Input reference here",
234
+ label="Reference",
235
+ )
236
+ reference_PPM = gr.Textbox(placeholder="Pneumatic Voice's PPM", label="Ref PPM", visible=False)
237
+
238
+ # Flagging setup
239
+
240
+ # Interface
241
+ # Participant Information
242
+ def record_part_info(name, gender, first_lng):
243
+ message = "Participant information is successfully collected."
244
+ id_str = "%s_%s_%s" % (name, gender[0], first_lng[0])
245
+
246
+ if name == None:
247
+ message = "ERROR: Name Information incomplete!"
248
+ id_str = "ERROR"
249
+
250
+ if gender == None:
251
+ message = "ERROR: Please select gender"
252
+ id_str = "ERROR"
253
+
254
+ if len(gender) > 1:
255
+ message = "ERROR: Please select one gender only"
256
+ id_str = "ERROR"
257
+
258
+ if first_lng == None:
259
+ message = "ERROR: Please select your english proficiency"
260
+ id_str = "ERROR"
261
+
262
+ if len(first_lng) > 1:
263
+ message = "ERROR: Please select one english proficiency only"
264
+ id_str = "ERROR"
265
+
266
+ return message, id_str
267
+
268
+
269
+ # information page not using now
270
+ name = gr.Textbox(placeholder="Name", label="Name")
271
+ gender = gr.CheckboxGroup(["Male", "Female"], label="gender")
272
+ first_lng = gr.CheckboxGroup(
273
+ [
274
+ "B1 Intermediate",
275
+ "B2: Upper Intermediate",
276
+ "C1: Advanced",
277
+ "C2: Proficient",
278
+ ],
279
+ label="English Proficiency (CEFR)",
280
+ )
281
+
282
+ msg = gr.Textbox(placeholder="Evaluation for valid participant", label="message")
283
+ id_str = gr.Textbox(placeholder="participant id", label="participant_id")
284
+
285
+ info = gr.Interface(
286
+ fn=record_part_info,
287
+ inputs=[name, gender, first_lng],
288
+ outputs=[msg, id_str],
289
+ title="Participant Information Page",
290
+ allow_flagging="never",
291
+ css="body {background-color: blue}",
292
+ )
293
+ # Experiment
294
+ if config["exp_id"] == None:
295
+ config["exp_id"] = Path(config_yaml).stem
296
+
297
+ ## This is the theme for the interface
298
+ css = """
299
+ .ref_text textarea {font-size: 40px !important}
300
+ .message textarea {font-size: 40px !important}
301
+ """
302
+
303
+ my_theme = gr.themes.Default().set(
304
+ button_primary_background_fill="#75DA99",
305
+ button_primary_background_fill_dark="#DEF2D7",
306
+ button_primary_text_color="black",
307
+ button_secondary_text_color="black",
308
+ )
309
+
310
+ # Callback for saving the recording
311
+ callback = gr.CSVLogger()
312
+
313
+ with gr.Blocks(css=css, theme=my_theme) as demo:
314
+ with gr.Column():
315
+ with gr.Row():
316
+ ref_audio = gr.Audio(
317
+ source="microphone",
318
+ type="filepath",
319
+ label="Reference_Audio",
320
+ container=True,
321
+ interactive=False,
322
+ visible=False,
323
+ )
324
+ with gr.Row():
325
+ eval_audio = gr.Audio(
326
+ source="microphone",
327
+ type="filepath",
328
+ container=True,
329
+ label="Audio_to_Evaluate",
330
+ )
331
+ b_redo = gr.ClearButton(
332
+ value="Redo", variant="stop", components=[eval_audio], size="sm"
333
+ )
334
+ reference_textbox = gr.Textbox(
335
+ value="Input reference here",
336
+ placeholder="Input reference here",
337
+ label="Reference",
338
+ interactive=True,
339
+ elem_classes="ref_text",
340
+ )
341
+ with gr.Accordion("Input for Development", open=False):
342
+ reference_id = gr.Textbox(
343
+ value="ID",
344
+ placeholder="Utter ID",
345
+ label="Reference_ID",
346
+ visible=True,
347
+ )
348
+ reference_PPM = gr.Textbox(
349
+ placeholder="Pneumatic Voice's PPM",
350
+ label="Ref PPM",
351
+ visible=True,
352
+ )
353
+ with gr.Row():
354
+ b = gr.Button(value="1.Submit", variant="primary", elem_classes="submit")
355
+
356
+ # TODO
357
+ # b_more = gr.Button(value="Show More", elem_classes="verbose")
358
+ with gr.Row():
359
+ inputs = [
360
+ ref_audio,
361
+ eval_audio,
362
+ reference_id,
363
+ reference_textbox,
364
+ reference_PPM,
365
+ ]
366
+ e = gr.Examples(examples, inputs, examples_per_page=5)
367
+
368
+ with gr.Column():
369
+ with gr.Row():
370
+ ## output block
371
+ msg = gr.Textbox(
372
+ placeholder="Recording Feedback",
373
+ label="Message",
374
+ interactive=False,
375
+ elem_classes="message",
376
+ )
377
+ with gr.Accordion("Output for Development", open=False):
378
+ wav_plot = gr.Plot(PlaceHolder="Wav/Pause Plot", label="wav_pause_plot", visible=True)
379
+
380
+ predict_mos = gr.Textbox(
381
+ placeholder="Predicted MOS",
382
+ label="Predicted MOS",
383
+ visible=True,
384
+ )
385
+
386
+ hyp = gr.Textbox(placeholder="Hypothesis", label="Hypothesis", visible=True)
387
+
388
+ wer = gr.Textbox(placeholder="Word Error Rate", label="WER", visible=True)
389
+
390
+ predict_pho = gr.Textbox(
391
+ placeholder="Predicted Phonemes",
392
+ label="Predicted Phonemes",
393
+ visible=True,
394
+ )
395
+
396
+ ppm = gr.Textbox(
397
+ placeholder="Phonemes per minutes",
398
+ label="PPM",
399
+ visible=True,
400
+ )
401
+ outputs = [
402
+ wav_plot,
403
+ predict_mos,
404
+ hyp,
405
+ wer,
406
+ predict_pho,
407
+ ppm,
408
+ msg,
409
+ ]
410
+
411
+ # b = gr.Button("Submit")
412
+ b.click(fn=calc_mos, inputs=inputs, outputs=outputs, api_name="Submit")
413
+
414
+ # Logger
415
+ callback.setup(
416
+ components=[
417
+ eval_audio,
418
+ reference_id,
419
+ reference_textbox,
420
+ reference_PPM,
421
+ predict_mos,
422
+ hyp,
423
+ wer,
424
+ ppm,
425
+ msg],
426
+ flagging_dir="./exp/%s" % config["exp_id"],
427
+ )
428
+
429
+ with gr.Row():
430
+ b2 = gr.Button("2. Save the Recording", variant="primary", elem_id="save")
431
+ js_confirmed_saving = "(x) => confirm('Recording Saved!')"
432
+ # eval_audio,
433
+ b2.click(
434
+ lambda *args: callback.flag(args),
435
+ inputs=[
436
+ eval_audio,
437
+ reference_id,
438
+ reference_textbox,
439
+ reference_PPM,
440
+ predict_mos,
441
+ hyp,
442
+ wer,
443
+ ppm,
444
+ msg,
445
+ ],
446
+ outputs=None,
447
+ preprocess=False,
448
+ api_name="flagging",
449
+ )
450
+ with gr.Row():
451
+ b3 = gr.ClearButton(
452
+ [
453
+ ref_audio,
454
+ eval_audio,
455
+ reference_id,
456
+ reference_textbox,
457
+ reference_PPM,
458
+ predict_mos,
459
+ hyp,
460
+ wer,
461
+ ppm,
462
+ msg,
463
+ ],
464
+ value="3.Clear All",
465
+ elem_id="clear",
466
+ )
467
+
468
+ demo.launch(share=True)
local/UV.py ADDED
@@ -0,0 +1,29 @@
1
+ import numpy as np
2
+ import librosa.display
3
+ import matplotlib.pyplot as plt
4
+
5
+ # Plot_UV
6
+
7
+
8
+ def plot_UV(signal, audio_interv, sr):
9
+ fig, ax = plt.subplots(nrows=2, sharex=True)
10
+ librosa.display.waveshow(signal, sr=sr, ax=ax[0])
11
+ ax[0].set_title("Signal")
12
+ ax[1].set_title("U/V")
13
+ uv_flag = np.zeros(len(signal))
14
+ for i in audio_interv:
15
+ uv_flag[i[0]: i[1]] = 1
16
+
17
+ ax[1].plot(np.arange(len(signal))/sr, uv_flag, "r")
18
+ ax[1].set_ylim([-0.1, 1.1])
19
+ return fig
20
+
21
+ # Get Speech Interval
22
+
23
+
24
+ def get_speech_interval(signal, db):
25
+ audio_interv = librosa.effects.split(signal, top_db=db)
26
+ pause_end = [x[0] for x in audio_interv[1:]]
27
+ pause_start = [x[1] for x in audio_interv[0: -1]]
28
+ pause_interv = [[x, y] for x, y in zip(pause_start, pause_end)]
29
+ return audio_interv, pause_interv
local/VCTK_preprocessing.py ADDED
@@ -0,0 +1,62 @@
1
+ # Kevin @ Laronix Dec. 2022
2
+ # Data processing at Laronix
3
+ import csv
4
+ import soundfile as sf
5
+ import pandas as pd
6
+ from pathlib import Path
7
+ import librosa
8
+ import sys
9
+ import numpy as np
10
+ import pdb
11
+ from rich.progress import track
12
+
13
+ wavdir = sys.argv[1]
14
+ txtdir = sys.argv[2]
15
+ thre_len = int(sys.argv[3])
16
+ origin_sr = int(sys.argv[4])
17
+ target_sr = int(sys.argv[5])
18
+
19
+ wavs = sorted(Path(wavdir).glob("**/*.wav"))
20
+ txts = sorted(Path(txtdir).glob("**/*.txt"))
21
+ target_dir = "./data/%s_%d_%d_len%d" % (
22
+ Path(wavdir).stem,
23
+ origin_sr,
24
+ target_sr,
25
+ thre_len,
26
+ )
27
+
28
+ Path.mkdir(Path(target_dir), exist_ok=True)
29
+ # pdb.set_trace()
30
+ tables = []
31
+ for x, y in track(
32
+ zip(wavs, txts), description="Processing...", total=len(wavs)
33
+ ):
34
+ label = 1
35
+ with open(y, "r") as f:
36
+ txt = f.readline()
37
+ if len(txt.split(" ")) <= thre_len:
38
+ label = 1
39
+ record = [x, Path(x).stem, txt, len(txt.split(" ")), label]
40
+ tables.append(record)
41
+ # Select length <= 10 words sentences for training
42
+ if len(txt.split(" ")) <= thre_len:
43
+ wav, sr = librosa.load(x, sr=origin_sr)
44
+ wav_ = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
45
+ sf.write(
46
+ Path(target_dir) / Path((x).stem + ".wav"),
47
+ data=wav_,
48
+ samplerate=target_sr,
49
+ )
50
+
51
+ D = pd.DataFrame(
52
+ tables, columns=["wav_path", "id", "text", "len", "length_label"]
53
+ )
54
+ D.to_csv(target_dir + ".datalog", sep=",")
55
+ print("Check data log at %s" % (target_dir + ".datalog"))
56
+
57
+ D.get(["id", "text"]).to_csv(
58
+ target_dir + ".txt", sep="\t", header=False, index=False, quoting=3
59
+ )
60
+
61
+ print("Generate id_text at %s" % (target_dir + ".txt"))
62
+ print("Finish")
local/data_preparation.py ADDED
@@ -0,0 +1,78 @@
1
+ import os
2
+ import pdb
3
+ import shutil
4
+ import pandas as pd
5
+ from datasets import Dataset, load_dataset
6
+
7
+ audio_dir = "./data/Patient_sil_trim_16k_normed_5_snr_40/"
8
+ # split_files = {"train": "data/Patient_sil_trim_16k_normed_5_snr_40/train.csv",
9
+ # "test": "data/Patient_sil_trim_16k_normed_5_snr_40/test.csv",
10
+ # "dev": "data/Patient_sil_trim_16k_normed_5_snr_40/dev.csv"}
11
+ src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
12
+ pdb.set_trace()
13
+ def train_dev_test_split(
14
+ dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1, metadata_output=False, root_dir=None
15
+ ):
16
+ """
17
+ input: dataset
18
+ dev_rate,
19
+ test_rate
20
+ seed
21
+ -------
22
+ Output:
23
+ dataset_dict{"train", "dev", "test"}
24
+ """
25
+ train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
26
+ test = train_dev_test["test"]
27
+ train_dev = train_dev_test["train"]
28
+
29
+ if len(train_dev) <= int(len(dataset) * dev_rate):
30
+ train = Dataset.from_dict({"audio": [], "transcription": []})
31
+ dev = train_dev
32
+ else:
33
+ train_dev = train_dev.train_test_split(
34
+ test_size=int(len(dataset) * dev_rate), seed=seed
35
+ )
36
+ train = train_dev["train"]
37
+ dev = train_dev["test"]
38
+
39
+ train_size = len(train)
40
+ dev_size = len(dev)
41
+ test_size = len(test)
42
+
43
+ print(f"Train Size: {len(train)}")
44
+ print(f"Dev Size: {len(dev)}")
45
+ print(f"Test Size: {len(test)}")
46
+ import pdb
47
+ if metadata_output:
48
+ pdb.set_trace()
49
+ train_df = pd.DataFrame(train)
50
+ dev_df = pd.DataFrame(dev)
51
+ test_df = pd.DataFrame(test)
52
+
53
+ if not os.path.exists(root_dir):
+ raise FileNotFoundError(root_dir)
57
+
58
+ # Create directories for train, dev, and test data
59
+ import pdb
60
+ if not os.path.exists(f'{root_dir}/train'):
61
+ os.makedirs(f'{root_dir}/train')
62
+ if not os.path.exists(f'{root_dir}/dev'):
63
+ os.makedirs(f'{root_dir}/dev')
64
+ if not os.path.exists(f'{root_dir}/test'):
65
+ os.makedirs(f'{root_dir}/test')
66
+
67
+ pdb.set_trace()
68
+ train_df.to_csv(f'{root_dir}/train/metadata.csv', index=False)
69
+
70
+ dev_df.to_csv(f'{root_dir}/dev/metadata.csv', index=False)
71
+
72
+ test_df.to_csv(f'{root_dir}/test/metadata.csv', index=False)
73
+
74
+ return train, dev, test
75
+
76
+ train, dev, test = train_dev_test_split(src_dataset, dev_rate=0.1, test_rate=0.1, seed=1, metadata_output=True, root_dir=audio_dir)
77
+
78
+ pdb.set_trace()
local/decode.py ADDED
@@ -0,0 +1,427 @@
1
+ fine_tuning_dir = "fine_tuned/SSD/model/Negel_79_AVA_script_conv_train_conv_dev/checkpoint-50"
2
+
3
+ from typing import Any, Dict, List, Union
4
+ from dataclasses import dataclass
5
+ from transformers import Seq2SeqTrainer
6
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperFeatureExtractor, Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperModel
7
+ import evaluate
8
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
+ from random import sample
10
+ from sys import flags
11
+ import gradio as gr
12
+ import torchaudio
13
+ import torch.nn as nn
14
+ import jiwer
15
+ import numpy as np
16
+ from rich import print as rprint
17
+ from rich.progress import track
18
+ from transformers import pipeline
19
+ import argparse
20
+ import yaml
21
+ import torch
22
+ from pathlib import Path
23
+ from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
24
+ from datasets import load_dataset, concatenate_datasets
25
+ from datasets import Dataset, Audio
26
+ import pdb
27
+ import string
28
+ import librosa
29
+ # local import
30
+ import sys
31
+
32
+ sys.path.append("src")
33
+ import lightning_module
34
+
35
+ torch.cuda.set_device("cuda:0")
36
+
37
+ audio_dir = "./data/Patient_sil_trim_16k_normed_5_snr_40"
38
+ healthy_dir = "./data/Healthy"
39
+ Fary_PAL_30 = "./data/Fary_PAL_p326_20230110_30"
40
+ John_p326 = "./data/John_p326/output"
41
+ John_video = "./data/20230103_video"
42
+ negel_79 = "./data/4_negel_79"
43
+
44
+ patient_T = "data/Patient_T/Patient_T"
45
+ patient_L = "data/Patient_L/Patient_L"
46
+ # Get Transcription, WER and PPM
47
+ """
48
+ TODO:
49
+ [DONE]: Automatic generating Config
50
+ """
51
+
52
+
53
+ sys.path.append("./src")
54
+
55
+
56
+ wer = evaluate.load("wer")
57
+
58
+ # root_path = Path(__file__).parents[1]
59
+
60
+
61
+ class ChangeSampleRate(nn.Module):
62
+ def __init__(self, input_rate: int, output_rate: int):
63
+ super().__init__()
64
+ self.output_rate = output_rate
65
+ self.input_rate = input_rate
66
+
67
+ def forward(self, wav: torch.tensor) -> torch.tensor:
68
+ # Only accepts 1-channel waveform input
69
+ wav = wav.view(wav.size(0), -1)
70
+ new_length = wav.size(-1) * self.output_rate // self.input_rate
71
+ indices = torch.arange(new_length) * (
72
+ self.input_rate / self.output_rate
73
+ )
74
+ round_down = wav[:, indices.long()]
75
+ round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
76
+ output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
77
+ 0
78
+ ) + round_up * indices.fmod(1.0).unsqueeze(0)
79
+ return output
80
+
81
+ # resample and clean text data
82
+
83
+
84
+ def dataclean(example):
85
+ # pdb.set_trace()
86
+ if example['audio']['sampling_rate'] != 16000:
87
+ resampled_audio = librosa.resample(y=example['audio']['array'],
88
+ orig_sr=example['audio']['sampling_rate'],
89
+ target_sr=16000)
90
+ # torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
91
+ # resampled_audio = resampler(example['audio']['array'])
92
+
93
+ return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
94
+ "transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
95
+ else:
96
+ return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
97
+
98
+ processor = AutoFeatureExtractor.from_pretrained(
99
+ "facebook/wav2vec2-base-960h"
100
+ )
101
+
102
+ def prepare_dataset(batch):
103
+ audio = batch["audio"]
104
+ batch = processor(
105
+ audio["array"], sampling_rate=audio["sampling_rate"], text=batch['transcription'])
106
+ batch["input_length"] = len(batch["input_values"][0])
107
+ return batch
108
+
109
+
110
+ negel_79_dataset = load_dataset("audiofolder", data_dir=negel_79, split="train")
111
+ negel_79_dataset = negel_79_dataset.map(dataclean)
112
+
113
+ def train_dev_test_split(dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1):
114
+ """
115
+ input: dataset
116
+ dev_rate,
117
+ test_rate
118
+ seed
119
+ -------
120
+ Output:
121
+ dataset_dict{"train", "dev", "test"}
122
+ """
123
+ train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
124
+ test = train_dev_test["test"]
125
+ train_dev = train_dev_test['train']
126
+
127
+ # pdb.set_trace()
128
+ if len(train_dev) <= int(len(dataset)*dev_rate):
129
+ train = Dataset.from_dict({"audio": [], "transcription": []})
130
+ dev = train_dev
131
+ else:
132
+ train_dev = train_dev.train_test_split(test_size=int(len(dataset)*dev_rate), seed=seed)
133
+ train = train_dev['train']
134
+ dev = train_dev['test']
135
+ return train, dev, test
136
+
137
+ # pdb.set_trace()
138
+ # P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1)
139
+ # P1tony_train_ = concatenate_datasets([P1tony_train,P1tony_scripted])
140
+ # pdb.set_trace()
141
+
142
+ Negel_79_train, Negel_79_dev, Negel_79_test = train_dev_test_split(negel_79_dataset, dev_rate=0.1, test_rate=0.1, seed=1)
143
+
144
+ # src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
145
+ # src_dataset = src_dataset.map(dataclean)
146
+
147
+ # healthy_test_dataset = load_dataset(
148
+ # "audiofolder", data_dir=healthy_dir, split='train')
149
+ # healthy_test_dataset = healthy_test_dataset.map(dataclean)
150
+
151
+ # Fary_PAL_test_dataset = load_dataset(
152
+ # "audiofolder", data_dir=Fary_PAL_30, split='train')
153
+ # Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)
154
+
155
+ # John_p326_test_dataset = load_dataset(
156
+ # "audiofolder", data_dir=John_p326, split='train')
157
+ # John_p326_test_dataset = John_p326_test_dataset.map(dataclean)
158
+
159
+ # John_video_test_dataset = load_dataset(
160
+ # "audiofolder", data_dir=John_video, split='train')
161
+ # John_video_test_dataset = John_video_test_dataset.map(dataclean)
162
+
163
+ # patient_T_test_dataset = load_dataset("audiofolder", data_dir=patient_T, split='train')
164
+ # patient_T_test_dataset = patient_T_test_dataset.map(dataclean)
165
+
166
+ # patient_L_test_dataset = load_dataset("audiofolder", data_dir=patient_L, split='train')
167
+ # patient_L_test_dataset = patient_L_test_dataset.map(dataclean)
168
+ # pdb.set_trace()
169
+
170
+ # train_dev / test
171
+ # ds = src_dataset.train_test_split(test_size=0.1, seed=1)
172
+
173
+ # dataset_libri = load_dataset(
174
+ # "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
175
+
176
+ # train_dev = ds['train']
177
+ # # train / dev
178
+ # train_dev = train_dev.train_test_split(
179
+ # test_size=int(len(src_dataset)*0.1), seed=1)
180
+ # # train/dev/test
181
+ # train = train_dev['train']
182
+ # test = ds['test']
183
+ # dev = train_dev['test']
184
+
185
+ # # pdb.set_trace()
186
+ # encoded_train = train.map(prepare_dataset, num_proc=4)
187
+ # encoded_dev = dev.map(prepare_dataset, num_proc=4)
188
+ # encoded_test = test.map(prepare_dataset, num_proc=4)
189
+
190
+ # encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
191
+ # encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
192
+ # encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
193
+ # encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
194
+ # pdb.set_trace()
195
+
196
+ WER = evaluate.load("wer")
197
+
198
+ # Whisper decoding
199
+
200
+ processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
201
+ model = WhisperForConditionalGeneration.from_pretrained(
202
+ "openai/whisper-medium").to("cuda:0")
203
+ tokenizer = WhisperTokenizer.from_pretrained(
204
+ "openai/whisper-medium", language="English", task="transcribe")
205
+
206
+ # Need to push tokenizer to hugginface/model to activate online API
207
+
208
+ # tokenizer.push_to_hub("KevinGeng/whipser_medium_en_PAL300_step25")
209
+ # import pdb
210
+ # pdb.set_trace()
211
+
212
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(
213
+ "openai/whisper-medium")
214
+
215
+
216
+ def whisper_prepare_dataset(batch):
217
+ # load and resample audio data from 48 to 16kHz
218
+ audio = batch["audio"]
219
+
220
+ # compute log-Mel input features from input audio array
221
+ batch["input_features"] = feature_extractor(
222
+ audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
223
+
224
+ # encode target text to label ids
225
+ batch["labels"] = tokenizer(batch["transcription"]).input_ids
226
+ return batch
227
+
228
+
229
+ torch.cuda.empty_cache()
230
+
231
+ training_args = Seq2SeqTrainingArguments(
232
+ # change to a repo name of your choice
233
+ output_dir="./whisper-medium-PAL128-25step",
234
+ per_device_train_batch_size=8,
235
+ gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
236
+ learning_rate=1e-5,
237
+ warmup_steps=100,
238
+ max_steps=1000,
239
+ gradient_checkpointing=True,
240
+ fp16=True,
241
+ evaluation_strategy="steps",
242
+ per_device_eval_batch_size=8,
243
+ predict_with_generate=True,
244
+ generation_max_length=512,
245
+ save_steps=100,
246
+ eval_steps=25,
247
+ logging_steps=100,
248
+ report_to=["tensorboard"],
249
+ load_best_model_at_end=True,
250
+ metric_for_best_model="wer",
251
+ greater_is_better=False,
252
+ push_to_hub=True,
253
+ )
254
+
255
+
256
+ def my_map_to_pred(batch):
257
+ # pdb.set_trace()
258
+ audio = batch["audio"]
259
+ input_features = processor(
260
+ audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
261
+ # batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
262
+ batch["reference"] = processor.tokenizer._normalize(batch['transcription'])
263
+
264
+ with torch.no_grad():
265
+ # predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
266
+ predicted_ids = model.generate(input_features.to("cuda"))[0]
267
+ transcription = tokenizer.decode(predicted_ids)
268
+ batch["prediction"] = model.tokenizer._normalize(transcription)
269
+ return batch
270
+
271
+
272
+ @dataclass
273
+ class DataCollatorSpeechSeq2SeqWithPadding:
274
+ processor: Any
275
+
276
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
277
+ # split inputs and labels since they have to be of different lengths and need different padding methods
278
+ # first treat the audio inputs by simply returning torch tensors
279
+ input_features = [{"input_features": feature["input_features"]}
280
+ for feature in features]
281
+ batch = self.processor.feature_extractor.pad(
282
+ input_features, return_tensors="pt")
283
+
284
+ # get the tokenized label sequences
285
+ label_features = [{"input_ids": feature["labels"]}
286
+ for feature in features]
287
+ # pad the labels to max length
288
+ labels_batch = self.processor.tokenizer.pad(
289
+ label_features, return_tensors="pt")
290
+
291
+ # replace padding with -100 to ignore loss correctly
292
+ labels = labels_batch["input_ids"].masked_fill(
293
+ labels_batch.attention_mask.ne(1), -100)
294
+
295
+ # if bos token is appended in previous tokenization step,
296
+ # cut bos token here as it's append later anyways
297
+ if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
298
+ labels = labels[:, 1:]
299
+
300
+ batch["labels"] = labels
301
+
302
+ return batch
303
+
304
+
305
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
306
+
307
+
308
+ def compute_metrics(pred):
309
+ pdb.set_trace()
310
+ pred_ids = pred.predictions
311
+ label_ids = pred.label_ids
312
+
313
+ # replace -100 with the pad_token_id
314
+ label_ids[label_ids == -100] = tokenizer.pad_token_id
315
+
316
+ # we do not want to group tokens when computing the metrics
317
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
318
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
319
+
320
+ wer = 100 * WER.compute(predictions=pred_str, references=label_str)
321
+
322
+ return {"wer": wer}
323
+
324
+ encode_negel_79_train = Negel_79_train.map(whisper_prepare_dataset, num_proc=4)
325
+ encode_negel_79_dev = Negel_79_dev.map(whisper_prepare_dataset, num_proc=4)
326
+ encode_negel_79_test = Negel_79_test.map(whisper_prepare_dataset, num_proc=4)
327
+ pdb.set_trace()
328
+ torch.cuda.empty_cache()
329
+
330
+ torch.cuda.empty_cache()
331
+
332
+ fine_tuned_model = WhisperForConditionalGeneration.from_pretrained(
333
+ fine_tuning_dir
334
+ ).to("cuda")
335
+ # "fine_tuned/SSD/model/whipser_medium_TEP_patient_T"
336
+ # "./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400"
337
+ #"./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-200"
338
+
339
+
340
+ def fine_tuned_map_to_pred(batch):
341
+ # pdb.set_trace()
342
+ audio = batch["audio"]
343
+ input_features = processor(
344
+ audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
345
+ # batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
346
+ batch["reference"] = processor.tokenizer._normalize(batch['transcription'])
347
+
348
+ with torch.no_grad():
349
+ # predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
350
+ predicted_ids = fine_tuned_model.generate(input_features.to("cuda"))[0]
351
+ transcription = tokenizer.decode(predicted_ids)
352
+ batch["prediction"] = tokenizer._normalize(transcription)
353
+ return batch
354
+
355
+
356
+ # output_dir="./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400",
357
+ testing_args = Seq2SeqTrainingArguments(
358
+ # change to a repo name of your choice
359
+ output_dir="fine_tuned/SSD/model/whipser_medium_TEP_patient_TL_TL",
360
+ per_device_train_batch_size=8,
361
+ gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
362
+ learning_rate=1e-5,
363
+ warmup_steps=100,
364
+ max_steps=1000,
365
+ gradient_checkpointing=True,
366
+ fp16=True,
367
+ evaluation_strategy="steps",
368
+ per_device_eval_batch_size=8,
369
+ predict_with_generate=True,
370
+ generation_max_length=512,
371
+ save_steps=100,
372
+ eval_steps=25,
373
+ logging_steps=100,
374
+ report_to=["tensorboard"],
375
+ load_best_model_at_end=True,
376
+ metric_for_best_model="wer",
377
+ greater_is_better=False,
378
+ push_to_hub=False,
379
+ )
380
+
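+ # These Seq2SeqTrainingArguments are set up for evaluation/prediction only in this script
+ # (trainer.train() stays commented out below); predict_with_generate=True makes the trainer
+ # pass generated token ids, rather than logits, to compute_metrics.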
381
+ predict_trainer = Seq2SeqTrainer(
382
+ args=testing_args,
383
+ model=fine_tuned_model,
384
+ data_collator=data_collator,
385
+ compute_metrics=compute_metrics,
386
+ tokenizer=processor.feature_extractor,
387
+ )
388
+
389
+ # trainer.train()
390
+ # fine tuned
391
+ # z_result = encoded_test.map(fine_tuned_map_to_pred)
392
+ pdb.set_trace()
393
+ z_result = encode_negel_79_test.map(fine_tuned_map_to_pred)
394
+ # 0.4692737430167598
395
+ z = WER.compute(references=z_result['reference'], predictions=z_result['prediction'])
396
+ # pdb.set_trace()
397
+ # z_hel_result = encoded_healthy.map(fine_tuned_map_to_pred)
398
+ # z_hel = WER.compute(references=z_hel_result['reference'], predictions=z_hel_result['prediction'])
399
+ # # 0.1591610117211598
400
+
401
+ # # pdb.set_trace()
402
+ # # z_fary_result = encoded_Fary.map(fine_tuned_map_to_pred)
403
+ # # z_far = WER.compute(references=z_fary_result['reference'], predictions=z_fary_result['prediction'])
404
+ # # 0.1791044776119403
405
+ # z_patient_LT = encoded_patient_TL_test.map(fine_tuned_map_to_pred)
406
+ # z_patient_LT_result = WER.compute(references=z_patient_LT['reference'], predictions=z_patient_LT['prediction'])
407
+ # z_patient_L = encoded_patient_L_test.map(fine_tuned_map_to_pred)
408
+ # z_patient_L_result = WER.compute(references=z_patient_L['reference'], predictions=z_patient_L['prediction'])
409
+ # z_patient_T = encoded_patient_T_test.map(fine_tuned_map_to_pred)
410
+ # z_patient_T_result = WER.compute(references=z_patient_T['reference'], predictions=z_patient_T['prediction'])
411
+
412
+ # # z_john_p326_result = encoded_John_p326.map(fine_tuned_map_to_pred)
413
+ # # pdb.set_trace()
414
+
415
+ # # z_john_p326 = WER.compute(references=z_john_p326_result['reference'], predictions=z_john_p326_result['prediction'])
416
+ # # 0.4648241206030151
417
+ pdb.set_trace()
418
+
419
+ # # y_John_video= fine_tuned_trainer.predict(encoded_John_video)
420
+ # # metrics={'test_loss': 2.665189743041992, 'test_wer': 0.7222222222222222, 'test_runtime': 0.1633, 'test_samples_per_second': 48.979, 'test_steps_per_second': 6.122})
421
+ # pdb.set_trace()
422
+
423
+ # p326 training
424
+ # metrics={'test_loss': 0.4804028868675232, 'test_wer': 0.21787709497206703, 'test_runtime': 0.3594, 'test_samples_per_second': 44.517, 'test_steps_per_second': 5.565})
425
+ # hel metrics={'test_loss': 1.6363693475723267, 'test_wer': 0.17951881554595928, 'test_runtime': 3.8451, 'test_samples_per_second': 41.611, 'test_steps_per_second': 5.201})
426
+ # Fary: metrics={'test_loss': 1.4633615016937256, 'test_wer': 0.5572139303482587, 'test_runtime': 0.6627, 'test_samples_per_second': 45.27, 'test_steps_per_second': 6.036})
427
+ # p326 large: metrics={'test_loss': 0.6568527817726135, 'test_wer': 0.2889447236180904, 'test_runtime': 0.7169, 'test_samples_per_second': 51.613, 'test_steps_per_second': 6.975})
local/duration_calcutator.py ADDED
@@ -0,0 +1,19 @@
1
+ import os
2
+ import librosa
3
+
4
+ folder_path = "/home/kevingeng/Disk2/laronix/laronix_automos/data/Patient_sil_trim_16k_normed_5_snr_40/Sentences" # Replace with the path to your folder
5
+ total_duration = 0.0
6
+
7
+ # Iterate through all files in the folder
8
+ for filename in os.listdir(folder_path):
9
+ file_path = os.path.join(folder_path, filename)
10
+ if os.path.isfile(file_path):
11
+ try:
12
+ # Load the audio file and get its duration
13
+ audio_data, sr = librosa.load(file_path, sr=None)  # keep the native sampling rate
14
+ duration = librosa.get_duration(y=audio_data, sr=sr)
15
+ total_duration += duration
16
+ except Exception as e:
17
+ print(f"Error processing file '{filename}': {e}")
18
+
19
+ print(f"Total duration of audio files in the folder: {total_duration} seconds.")
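+ # Note: loading with sr=None keeps each file's native sampling rate, so the reported total
+ # reflects the durations on disk rather than 22,050 Hz resampled copies.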
local/fine-tuning.py ADDED
@@ -0,0 +1,346 @@
1
+ """
2
+ TODO:
3
+ + [x] Load Configuration
4
+ + [ ] Multi ASR Engine
5
+ + [ ] Batch / Real Time support
6
+ """
7
+ from pathlib import Path
8
+ from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
9
+ from datasets import load_dataset
10
+ from datasets import Dataset, Audio
11
+ import pdb
12
+ import string
13
+ # local import
14
+ import sys
15
+
16
+ sys.path.append("src")
17
+
18
+ # token_model = AutoModelForCTC.from_pretrained(
19
+ # "facebook/wav2vec2-base-960h"
20
+ # )
21
+
22
+ # ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
23
+
24
+ audio_path = "/Users/kevingeng/Laronix/Laronix_PAL_ASR_Offline_Plot/data/samples/3_Healthy1.wav"
25
+
26
+ audio_dir= "/Users/kevingeng/Laronix/laronix_automos/data/Patient_sil_trim_16k_normed_5_snr_40/"
27
+ # tgt_audio_dir= "/Users/kevingeng/Laronix/Dataset/Pneumatic/automos"
28
+
29
+ # src_audio_list = sorted(Path(src_audio_dir).glob("**/*.wav"))
30
+ # src_audio_list = [str(x) for x in src_audio_list]
31
+ # src_audio_dict = {"audio": src_audio_list}
32
+ # src_dataset = Dataset.from_dict(src_audio_dict).cast_column("audio", Audio())
33
+
34
+ # tgt_audio_list = sorted(Path(tgt_audio_dir).glob("**/*.wav"))
35
+ # tgt_audio_list = [str(x) for x in tgt_audio_list]
36
+ # tgt_audio_dict = {"audio": tgt_audio_list}
37
+ # tgt_dataset = Dataset.from_dict(tgt_audio_dict).cast_column("audio", Audio())
38
+
39
+ # Get Transcription, WER and PPM
40
+ """
41
+ TODO:
42
+ [DONE]: Automatic generating Config
43
+ """
44
+
45
+ import yaml
46
+ import argparse
47
+ import sys
48
+ from pathlib import Path
49
+
50
+ sys.path.append("./src")
51
+ import lightning_module
52
+ from UV import plot_UV, get_speech_interval
53
+ from transformers import pipeline
54
+ from rich.progress import track
55
+ from rich import print as rprint
56
+ import numpy as np
57
+ import jiwer
58
+ import pdb
59
+ import torch.nn as nn
60
+ import torch
61
+ import torchaudio
62
+ import gradio as gr
63
+ from sys import flags
64
+ from random import sample
65
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
66
+
67
+ # root_path = Path(__file__).parents[1]
68
+
69
+ class ChangeSampleRate(nn.Module):
70
+ def __init__(self, input_rate: int, output_rate: int):
71
+ super().__init__()
72
+ self.output_rate = output_rate
73
+ self.input_rate = input_rate
74
+
75
+ def forward(self, wav: torch.tensor) -> torch.tensor:
76
+ # Only accepts 1-channel waveform input
77
+ wav = wav.view(wav.size(0), -1)
78
+ new_length = wav.size(-1) * self.output_rate // self.input_rate
79
+ indices = torch.arange(new_length) * (
80
+ self.input_rate / self.output_rate
81
+ )
82
+ round_down = wav[:, indices.long()]
83
+ round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
84
+ output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
85
+ 0
86
+ ) + round_up * indices.fmod(1.0).unsqueeze(0)
87
+ return output
88
+
89
+
90
+ model = lightning_module.BaselineLightningModule.load_from_checkpoint(
91
+ "./src/epoch=3-step=7459.ckpt"
92
+ ).eval()
93
+
94
+
95
+ def calc_wer(audio_path, ref):
96
+ wav, sr = torchaudio.load(audio_path)
97
+ osr = 16_000
98
+ batch = wav.unsqueeze(0).repeat(10, 1, 1)
99
+ csr = ChangeSampleRate(sr, osr)
100
+ out_wavs = csr(wav)
101
+ # ASR
102
+ trans = p(audio_path)["text"]
103
+ # WER
104
+ wer = jiwer.wer(
105
+ ref,
106
+ trans,
107
+ truth_transform=transformation,
108
+ hypothesis_transform=transformation,
109
+ )
110
+ return trans, wer
111
+
112
+ # if __name__ == "__main__":
113
+ # # Argparse
114
+ # parser = argparse.ArgumentParser(
115
+ # prog="get_ref_PPM",
116
+ # description="Generate Phoneme per Minute (and Voice/Unvoice plot)",
117
+ # epilog="",
118
+ # )
119
+ # parser.add_argument(
120
+ # "--tag",
121
+ # type=str,
122
+ # default=None,
123
+ # required=False,
124
+ # help="ID tag for output *.csv",
125
+ # )
126
+
127
+ # parser.add_argument("--ref_txt", type=str, required=True, help="Reference TXT")
128
+ # parser.add_argument(
129
+ # "--ref_wavs", type=str, required=True, help="Reference WAVs"
130
+ # )
131
+
132
+ # parser.add_argument(
133
+ # "--output_dir",
134
+ # type=str,
135
+ # required=True,
136
+ # help="Output Directory for *.csv",
137
+ # )
138
+ # parser.add_argument(
139
+ # "--to_config",
140
+ # choices=["True", "False"],
141
+ # default="False",
142
+ # help="Generating Config from .txt and wavs/*wav",
143
+ # )
144
+
145
+
146
+ # args = parser.parse_args()
147
+
148
+ # refs = np.loadtxt(args.ref_txt, delimiter="\n", dtype="str")
149
+ # refs_ids = [x.split()[0] for x in refs]
150
+ # refs_txt = [" ".join(x.split()[1:]) for x in refs]
151
+ # ref_wavs = [str(x) for x in sorted(Path(args.ref_wavs).glob("**/*.wav"))]
152
+ # # pdb.set_trace()
153
+ # try:
154
+ # len(refs) == len(ref_wavs)
155
+ # except ValueError:
156
+ # print("Error: Text and Wavs don't match")
157
+ # exit()
158
+
159
+ # # ASR part
160
+ # p = pipeline("automatic-speech-recognition")
161
+
162
+ # # WER part
163
+ # transformation = jiwer.Compose(
164
+ # [
165
+ # jiwer.ToLowerCase(),
166
+ # jiwer.RemoveWhiteSpace(replace_by_space=True),
167
+ # jiwer.RemoveMultipleSpaces(),
168
+ # jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
169
+ # ]
170
+ # )
171
+
172
+ # # WPM part
173
+ # processor = Wav2Vec2Processor.from_pretrained(
174
+ # "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
175
+ # )
176
+ # phoneme_model = Wav2Vec2ForCTC.from_pretrained(
177
+ # "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
178
+ # )
179
+ # # phoneme_model = pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
180
+
181
+ # description = """
182
+ # MOS prediction demo using UTMOS-strong w/o phoneme encoder model, \
183
+ # which is trained on the main track dataset.
184
+ # This demo only accepts .wav format. Best at 16 kHz sampling rate.
185
+
186
+ # Paper is available [here](https://arxiv.org/abs/2204.02152)
187
+
188
+ # Add ASR based on wav2vec-960, currently only English available.
189
+ # Add WER interface.
190
+ # """
191
+
192
+ # referance_id = gr.Textbox(
193
+ # value="ID", placeholder="Utter ID", label="Reference_ID"
194
+ # )
195
+ # referance_textbox = gr.Textbox(
196
+ # value="", placeholder="Input reference here", label="Reference"
197
+ # )
198
+ # # Set up interface
199
+ # result = []
200
+ # result.append("id, trans, wer")
201
+
202
+
203
+ # for id, x, y in track(
204
+ # zip(refs_ids, ref_wavs, refs_txt),
205
+ # total=len(refs_ids),
206
+ # description="Loading references information",
207
+ # ):
208
+ # trans, wer = calc_wer(x, y)
209
+ # record = ",".join(
210
+ # [
211
+ # id,
212
+ # str(trans),
213
+ # str(wer)
214
+ # ]
215
+ # )
216
+ # result.append(record)
217
+
218
+ # # Output
219
+ # if args.tag == None:
220
+ # args.tag = Path(args.ref_wavs).stem
221
+ # # Make output_dir
222
+ # # pdb.set_trace()
223
+ # Path.mkdir(Path(args.output_dir), exist_ok=True)
224
+ # # pdb.set_trace()
225
+ # with open("%s/%s.csv" % (args.output_dir, args.tag), "w") as f:
226
+ # print("\n".join(result), file=f)
227
+
228
+ # # Generating config
229
+ # if args.to_config == "True":
230
+ # config_dict = {
231
+ # "exp_id": args.tag,
232
+ # "ref_txt": args.ref_txt,
233
+ # "ref_feature": "%s/%s.csv" % (args.output_dir, args.tag),
234
+ # "ref_wavs": args.ref_wavs,
235
+ # "thre": {
236
+ # "minppm": 100,
237
+ # "maxppm": 100,
238
+ # "WER": 0.1,
239
+ # "AUTOMOS": 4.0,
240
+ # },
241
+ # "auth": {"username": None, "password": None},
242
+ # }
243
+ # with open("./config/%s.yaml" % args.tag, "w") as config_f:
244
+ # rprint("Dumping as config ./config/%s.yaml" % args.tag)
245
+ # rprint(config_dict)
246
+ # yaml.dump(config_dict, stream=config_f)
247
+ # rprint("Change parameter ./config/%s.yaml if necessary" % args.tag)
248
+ # print("Reference Dumping Finished")
249
+ def dataclean(example):
250
+ return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
251
+
252
+ # processor = AutoFeatureExtractor.from_pretrained(
253
+ # "facebook/wav2vec2-base-960h"
254
+ # )
255
+ processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
256
+
257
+ def prepare_dataset(batch):
258
+ audio = batch["audio"]
259
+ batch = processor(audio["array"], sampling_rate = audio["sampling_rate"], text=batch['transcription'])
260
+ batch["input_length"] = len(batch["input_values"][0])
261
+ return batch
262
+
263
+ src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
264
+ src_dataset = src_dataset.map(dataclean)
265
+ # train_dev / test
266
+ ds = src_dataset.train_test_split(test_size=0.1)
267
+
268
+ train_dev = ds['train']
269
+ # train / dev
270
+ train_dev = train_dev.train_test_split(test_size=int(len(src_dataset)*0.1))
271
+ # train/dev/test
272
+ train = train_dev['train']
273
+ test = ds['test']
274
+ dev = train_dev['test']
275
+
276
+ # pdb.set_trace()
277
+ import numpy as np
278
+
279
+
280
+ def compute_metrics(pred):
281
+ pred_logits = pred.predictions
282
+ pred_ids = np.argmax(pred_logits, axis=-1)
283
+
284
+ pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
285
+
286
+ pred_str = processor.batch_decode(pred_ids)
287
+ label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
288
+
289
+ wer = jiwer.wer(label_str, pred_str)  # word error rate via jiwer (imported above)
290
+
291
+ return {"wer": wer}
292
+
293
+
294
+ pdb.set_trace()
295
+ # TOKENLIZER("data/samples/5_Laronix1.wav")
296
+ # pdb.set_trace()
297
+ # tokenizer
298
+ tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
299
+
300
+ encoded_train = train.map(prepare_dataset, num_proc=4)
+ encoded_dev = dev.map(prepare_dataset, num_proc=4)
301
+
302
+ from transformers import AutoModelForCTC, TrainingArguments, Trainer
303
+
304
+ model = AutoModelForCTC.from_pretrained(
305
+ "facebook/wav2vec2-base",
306
+ ctc_loss_reduction="mean",
307
+ pad_token_id=processor.tokenizer.pad_token_id,
308
+ )
309
+ pdb.set_trace()
310
+
311
+ training_args = TrainingArguments(
312
+ output_dir="my_awesome_asr_mind_model",
313
+ per_device_train_batch_size=8,
314
+ gradient_accumulation_steps=2,
315
+ learning_rate=1e-5,
316
+ warmup_steps=500,
317
+ max_steps=2000,
318
+ gradient_checkpointing=True,
319
+ fp16=True,
320
+ group_by_length=True,
321
+ evaluation_strategy="steps",
322
+ per_device_eval_batch_size=8,
323
+ save_steps=1000,
324
+ eval_steps=1000,
325
+ logging_steps=25,
326
+ load_best_model_at_end=True,
327
+ metric_for_best_model="wer",
328
+ greater_is_better=False,
329
+ push_to_hub=True,
330
+ )
331
+
332
+ pdb.set_trace()
333
+ trainer = Trainer(
334
+ model=model,
335
+ args=training_args,
336
+ train_dataset=encoded_train,
337
+ eval_dataset=encoded_dev,
338
+ tokenizer=processor.feature_extractor,
339
+ compute_metrics=compute_metrics,
340
+ )
341
+ pdb.set_trace()
342
+ # data_collator=data_collator,
343
+
344
+ trainer.train()
345
+ # x = tokenizer(test['transcription'][0])
346
+
local/get_ASR.py ADDED
@@ -0,0 +1,232 @@
1
+ # Get Transcription, WER and PPM
2
+ """
3
+ TODO:
4
+ [DONE]: Automatic generating Config
5
+ """
6
+
7
+ import yaml
8
+ import argparse
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ sys.path.append("./src")
13
+ import lightning_module
14
+ from UV import plot_UV, get_speech_interval
15
+ from transformers import pipeline
16
+ from rich.progress import track
17
+ from rich import print as rprint
18
+ import numpy as np
19
+ import jiwer
20
+ import pdb
21
+ import torch.nn as nn
22
+ import torch
23
+ import torchaudio
24
+ import gradio as gr
25
+ from sys import flags
26
+ from random import sample
27
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
28
+
29
+ # root_path = Path(__file__).parents[1]
30
+
31
+ class ChangeSampleRate(nn.Module):
32
+ def __init__(self, input_rate: int, output_rate: int):
33
+ super().__init__()
34
+ self.output_rate = output_rate
35
+ self.input_rate = input_rate
36
+
37
+ def forward(self, wav: torch.tensor) -> torch.tensor:
38
+ # Only accepts 1-channel waveform input
39
+ wav = wav.view(wav.size(0), -1)
40
+ new_length = wav.size(-1) * self.output_rate // self.input_rate
41
+ indices = torch.arange(new_length) * (
42
+ self.input_rate / self.output_rate
43
+ )
44
+ round_down = wav[:, indices.long()]
45
+ round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
46
+ output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
47
+ 0
48
+ ) + round_up * indices.fmod(1.0).unsqueeze(0)
49
+ return output
50
+
51
+
52
+ model = lightning_module.BaselineLightningModule.load_from_checkpoint(
53
+ "./src/epoch=3-step=7459.ckpt"
54
+ ).eval()
55
+
56
+
57
+ def calc_wer(audio_path, ref, ASR_pipeline):
58
+ wav, sr = torchaudio.load(audio_path)
59
+ osr = 16_000
60
+ batch = wav.unsqueeze(0).repeat(10, 1, 1)
61
+ csr = ChangeSampleRate(sr, osr)
62
+ out_wavs = csr(wav)
63
+ # ASR
64
+ trans = ASR_pipeline(audio_path)["text"]
65
+ # WER
66
+ wer = jiwer.wer(
67
+ ref,
68
+ trans,
69
+ truth_transform=transformation,
70
+ hypothesis_transform=transformation,
71
+ )
72
+ return trans, wer
73
+
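+ # Usage sketch (hypothetical file and reference): calc_wer("wavs/utt1.wav", "THE RAINBOW PASSAGE",
+ # ASR_pipeline=ASR_pipeline) returns the pipeline transcription and its WER against the reference.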
74
+ if __name__ == "__main__":
75
+ # Argparse
76
+ parser = argparse.ArgumentParser(
77
+ prog="get_ref_PPM",
78
+ description="Generate Phoneme per Minute (and Voice/Unvoice plot)",
79
+ epilog="",
80
+ )
81
+ parser.add_argument(
82
+ "--tag",
83
+ type=str,
84
+ default=None,
85
+ required=False,
86
+ help="ID tag for output *.csv",
87
+ )
88
+
89
+ parser.add_argument("--ref_txt", type=str, required=True, help="Reference TXT")
90
+ parser.add_argument(
91
+ "--ref_wavs", type=str, required=True, help="Reference WAVs"
92
+ )
93
+ parser.add_argument(
94
+ "--metadata",
95
+ type=str,
96
+ required=False,
97
+ help="metadata.csv including wav_id and reference",
98
+ )
99
+
100
+ parser.add_argument(
101
+ "--model",
102
+ type=str,
103
+ default='whisper-medium-FT',
104
+ choices=['wav2vec+ctc', 'whisper-medium-FT', 'whisper-large-v2'],
105
+ help="ASR engine for evaluation:\n ver1: wav2vec+ctc \n ver2: whisper-medium (fine-tuned)\n ver3: whisper-large-v2",
106
+ )
107
+
108
+ parser.add_argument(
109
+ "--output_dir",
110
+ type=str,
111
+ required=True,
112
+ help="Output Directory for *.csv",
113
+ )
114
+
115
+ parser.add_argument(
116
+ "--to_config",
117
+ choices=["True", "False"],
118
+ default="False",
119
+ help="Generating Config from .txt and wavs/*wav",
120
+ )
121
+
122
+
123
+ args = parser.parse_args()
124
+
125
+ refs = np.loadtxt(args.ref_txt, delimiter="\n", dtype="str")
126
+ refs_ids = [x.split()[0] for x in refs]
127
+ refs_txt = [" ".join(x.split()[1:]) for x in refs]
128
+ ref_wavs = [str(x) for x in sorted(Path(args.ref_wavs).glob("**/*.wav"))]
129
+ # pdb.set_trace()
130
+ try:
131
+ assert len(refs) == len(ref_wavs)
132
+ except AssertionError:
133
+ print("Error: Text and Wavs don't match")
134
+ exit()
135
+
136
+ # ASR part
137
+ if args.model == "whisper-medium-FT":
138
+ ASR_pipeline = pipeline("automatic-speech-recognition", model="KevinGeng/whipser_medium_en_PAL300_step25")
139
+ elif args.model == "wav2vec+ctc":
140
+ ASR_pipeline = pipeline("automatic-speech-recognition")
141
+ elif args.model == "whisper-large-v2":
142
+ ASR_pipeline = pipeline("automatic-speech-recognition", model="openai/whisper-large-v2")
143
+
144
+ # pdb.set_trace()
145
+ # WER part
146
+ transformation = jiwer.Compose(
147
+ [
148
+ jiwer.ToLowerCase(),
149
+ jiwer.RemoveWhiteSpace(replace_by_space=True),
150
+ jiwer.RemoveMultipleSpaces(),
151
+ jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
152
+ ]
153
+ )
154
+
155
+ # WPM part
156
+ processor = Wav2Vec2Processor.from_pretrained(
157
+ "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
158
+ )
159
+ phoneme_model = Wav2Vec2ForCTC.from_pretrained(
160
+ "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
161
+ )
162
+ # phoneme_model = pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
163
+
164
+ description = """
165
+ MOS prediction demo using UTMOS-strong w/o phoneme encoder model, \
166
+ which is trained on the main track dataset.
167
+ This demo only accepts .wav format. Best at 16 kHz sampling rate.
168
+
169
+ Paper is available [here](https://arxiv.org/abs/2204.02152)
170
+
171
+ Add ASR based on wav2vec-960, currently only English available.
172
+ Add WER interface.
173
+ """
174
+
175
+ referance_id = gr.Textbox(
176
+ value="ID", placeholder="Utter ID", label="Reference_ID"
177
+ )
178
+ referance_textbox = gr.Textbox(
179
+ value="", placeholder="Input reference here", label="Reference"
180
+ )
181
+ # Set up interface
182
+ result = []
183
+ result.append("id,ref,hyp,wer")
184
+
185
+
186
+ for id, x, y in track(
187
+ zip(refs_ids, ref_wavs, refs_txt),
188
+ total=len(refs_ids),
189
+ description="Loading references information",
190
+ ):
191
+ trans, wer = calc_wer(x, y, ASR_pipeline=ASR_pipeline)
192
+ record = ",".join(
193
+ [
194
+ id,
195
+ str(y),
196
+ str(trans),
197
+ str(wer)
198
+ ]
199
+ )
200
+ result.append(record)
201
+
202
+ # Output
203
+ if args.tag == None:
204
+ args.tag = Path(args.ref_wavs).stem
205
+ # Make output_dir
206
+ # pdb.set_trace()
207
+ Path.mkdir(Path(args.output_dir), exist_ok=True)
208
+ # pdb.set_trace()
209
+ with open("%s/%s.csv" % (args.output_dir, args.tag), "w") as f:
210
+ print("\n".join(result), file=f)
211
+
212
+ # Generating config
213
+ if args.to_config == "True":
214
+ config_dict = {
215
+ "exp_id": args.tag,
216
+ "ref_txt": args.ref_txt,
217
+ "ref_feature": "%s/%s.csv" % (args.output_dir, args.tag),
218
+ "ref_wavs": args.ref_wavs,
219
+ "thre": {
220
+ "minppm": 100,
221
+ "maxppm": 100,
222
+ "WER": 0.1,
223
+ "AUTOMOS": 4.0,
224
+ },
225
+ "auth": {"username": None, "password": None},
226
+ }
227
+ with open("./config/%s.yaml" % args.tag, "w") as config_f:
228
+ rprint("Dumping as config ./config/%s.yaml" % args.tag)
229
+ rprint(config_dict)
230
+ yaml.dump(config_dict, stream=config_f)
231
+ rprint("Change parameter ./config/%s.yaml if necessary" % args.tag)
232
+ print("Reference Dumping Finished")
local/get_ref_PPM.py ADDED
@@ -0,0 +1,269 @@
1
+ # Get Transcription, WER and PPM
2
+ """
3
+ TODO:
4
+ [DONE]: Automatic generating Config
5
+ """
6
+
7
+ import yaml
8
+ import argparse
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ sys.path.append("./src")
13
+ import lightning_module
14
+ from UV import plot_UV, get_speech_interval
15
+ from transformers import pipeline
16
+ from rich.progress import track
17
+ from rich import print as rprint
18
+ import numpy as np
19
+ import jiwer
20
+ import pdb
21
+ import torch.nn as nn
22
+ import torch
23
+ import torchaudio
24
+ import gradio as gr
25
+ from sys import flags
26
+ from random import sample
27
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
28
+
29
+ # root_path = Path(__file__).parents[1]
30
+
31
+ class ChangeSampleRate(nn.Module):
32
+ def __init__(self, input_rate: int, output_rate: int):
33
+ super().__init__()
34
+ self.output_rate = output_rate
35
+ self.input_rate = input_rate
36
+
37
+ def forward(self, wav: torch.tensor) -> torch.tensor:
38
+ # Only accepts 1-channel waveform input
39
+ wav = wav.view(wav.size(0), -1)
40
+ new_length = wav.size(-1) * self.output_rate // self.input_rate
41
+ indices = torch.arange(new_length) * (
42
+ self.input_rate / self.output_rate
43
+ )
44
+ round_down = wav[:, indices.long()]
45
+ round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
46
+ output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
47
+ 0
48
+ ) + round_up * indices.fmod(1.0).unsqueeze(0)
49
+ return output
50
+
51
+
52
+ model = lightning_module.BaselineLightningModule.load_from_checkpoint(
53
+ "./src/epoch=3-step=7459.ckpt"
54
+ ).eval()
55
+
56
+
57
+ def calc_mos(audio_path, ref):
58
+ wav, sr = torchaudio.load(audio_path)
59
+ osr = 16_000
60
+ batch = wav.unsqueeze(0).repeat(10, 1, 1)
61
+ csr = ChangeSampleRate(sr, osr)
62
+ out_wavs = csr(wav)
63
+ # ASR
64
+ trans = p(audio_path)["text"]
65
+ # WER
66
+ wer = jiwer.wer(
67
+ ref,
68
+ trans,
69
+ truth_transform=transformation,
70
+ hypothesis_transform=transformation,
71
+ )
72
+ # MOS
73
+ batch = {
74
+ "wav": out_wavs,
75
+ "domains": torch.tensor([0]),
76
+ "judge_id": torch.tensor([288]),
77
+ }
78
+ with torch.no_grad():
79
+ output = model(batch)
80
+ predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
81
+ # Phonemes per minute (PPM)
82
+ with torch.no_grad():
83
+ logits = phoneme_model(out_wavs).logits
84
+ phone_predicted_ids = torch.argmax(logits, dim=-1)
85
+ phone_transcription = processor.batch_decode(phone_predicted_ids)
86
+ lst_phonemes = phone_transcription[0].split(" ")
87
+ wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
88
+ ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
89
+ # if float(predic_mos) >= 3.0:
90
+ # torchaudio.save("good.wav", wav,sr)
91
+
92
+ return predic_mos, trans, wer, phone_transcription, ppm
93
+
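+ # PPM here is (number of decoded phoneme tokens) / (voiced duration in seconds) * 60, with
+ # the voiced duration taken from torchaudio.functional.vad over the input waveform.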
94
+ if __name__ == "__main__":
95
+ # Argparse
96
+ parser = argparse.ArgumentParser(
97
+ prog="get_ref_PPM",
98
+ description="Generate Phoneme per Minute (and Voice/Unvoice plot)",
99
+ epilog="",
100
+ )
101
+ parser.add_argument(
102
+ "--tag",
103
+ type=str,
104
+ default=None,
105
+ required=False,
106
+ help="ID tag for output *.csv",
107
+ )
108
+
109
+ parser.add_argument("--ref_txt", type=str, required=True, help="Reference TXT")
110
+ parser.add_argument(
111
+ "--ref_wavs", type=str, required=True, help="Reference WAVs"
112
+ )
113
+
114
+ parser.add_argument(
115
+ "--output_dir",
116
+ type=str,
117
+ required=True,
118
+ help="Output Directory for *.csv",
119
+ )
120
+ parser.add_argument(
121
+ "--to_config",
122
+ choices=["True", "False"],
123
+ default="False",
124
+ help="Generating Config from .txt and wavs/*wav",
125
+ )
126
+
127
+ parser.add_argument(
128
+ "--UV_flag",
129
+ choices=["True", "False"],
130
+ default="False",
131
+ help="Toggle for U/V plot",
132
+ )
133
+ parser.add_argument(
134
+ "--UV_thre", type=float, default=40, help="U/V threshold dB"
135
+ )
136
+ args = parser.parse_args()
137
+
138
+ refs = np.loadtxt(args.ref_txt, delimiter="\n", dtype="str")
139
+ refs_ids = [x.split()[0] for x in refs]
140
+ refs_txt = [" ".join(x.split()[1:]) for x in refs]
141
+ ref_wavs = [str(x) for x in sorted(Path(args.ref_wavs).glob("**/*.wav"))]
142
+ # pdb.set_trace()
143
+ try:
144
+ assert len(refs) == len(ref_wavs)
145
+ except AssertionError:
146
+ print("Error: Text and Wavs don't match")
147
+ exit()
148
+
149
+ # ASR part
150
+ p = pipeline("automatic-speech-recognition")
151
+
152
+ # WER part
153
+ transformation = jiwer.Compose(
154
+ [
155
+ jiwer.ToLowerCase(),
156
+ jiwer.RemoveWhiteSpace(replace_by_space=True),
157
+ jiwer.RemoveMultipleSpaces(),
158
+ jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
159
+ ]
160
+ )
161
+
162
+ # WPM part
163
+ processor = Wav2Vec2Processor.from_pretrained(
164
+ "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
165
+ )
166
+ phoneme_model = Wav2Vec2ForCTC.from_pretrained(
167
+ "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
168
+ )
169
+ # phoneme_model = pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
170
+
171
+ description = """
172
+ MOS prediction demo using UTMOS-strong w/o phoneme encoder model, \
173
+ which is trained on the main track dataset.
174
+ This demo only accepts .wav format. Best at 16 kHz sampling rate.
175
+
176
+ Paper is available [here](https://arxiv.org/abs/2204.02152)
177
+
178
+ Add ASR based on wav2vec-960, currently only English available.
179
+ Add WER interface.
180
+ """
181
+
182
+ referance_id = gr.Textbox(
183
+ value="ID", placeholder="Utter ID", label="Reference_ID"
184
+ )
185
+ referance_textbox = gr.Textbox(
186
+ value="", placeholder="Input reference here", label="Reference"
187
+ )
188
+ # Set up interface
189
+ result = []
190
+ result.append("id, pred_mos, trans, wer, pred_phone, ppm")
191
+
192
+ if args.UV_flag == "False":
193
+ for id, x, y in track(
194
+ zip(refs_ids, ref_wavs, refs_txt),
195
+ total=len(refs_ids),
196
+ description="Loading references information",
197
+ ):
198
+ predic_mos, trans, wer, phone_transcription, ppm = calc_mos(x, y)
199
+ record = ",".join(
200
+ [
201
+ id,
202
+ str(predic_mos),
203
+ str(trans),
204
+ str(wer),
205
+ str(phone_transcription),
206
+ str(ppm),
207
+ ]
208
+ )
209
+ result.append(record)
210
+
211
+ elif args.UV_flag == "True":
212
+ fig_tardir = Path(args.ref_wavs) / Path("PPM_figs")
213
+ Path.mkdir(Path(args.ref_wavs) / Path("PPM_figs"), exist_ok=True)
214
+
215
+ for id, x, y in track(
216
+ zip(refs_ids, ref_wavs, refs_txt),
217
+ total=len(refs_ids),
218
+ description="Loading references information",
219
+ ):
220
+ # UV ploting
221
+ wav, sr = torchaudio.load(x)
222
+ wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
223
+ a_h, p_h = get_speech_interval(wav_vad.numpy(), db=args.UV_thre)
224
+ fig_h = plot_UV(wav_vad.numpy().squeeze(), a_h, sr=sr)
225
+ fig_h.savefig(Path(fig_tardir) / Path(id + ".png"), dpi=200)
226
+ # Acoustic calculation
227
+ predic_mos, trans, wer, phone_transcription, ppm = calc_mos(x, y)
228
+ record = ",".join(
229
+ [
230
+ id,
231
+ str(predic_mos),
232
+ str(trans),
233
+ str(wer),
234
+ str(phone_transcription),
235
+ str(ppm),
236
+ ]
237
+ )
238
+ result.append(record)
239
+ # Output
240
+ if args.tag == None:
241
+ args.tag = Path(args.ref_wavs).stem
242
+ # Make output_dir
243
+ # pdb.set_trace()
244
+ Path.mkdir(Path(args.output_dir), exist_ok=True)
245
+ # pdb.set_trace()
246
+ with open("%s/%s.csv" % (args.output_dir, args.tag), "w") as f:
247
+ print("\n".join(result), file=f)
248
+
249
+ # Generating config
250
+ if args.to_config == "True":
251
+ config_dict = {
252
+ "exp_id": args.tag,
253
+ "ref_txt": args.ref_txt,
254
+ "ref_feature": "%s/%s.csv" % (args.output_dir, args.tag),
255
+ "ref_wavs": args.ref_wavs,
256
+ "thre": {
257
+ "minppm": 100,
258
+ "maxppm": 100,
259
+ "WER": 0.1,
260
+ "AUTOMOS": 4.0,
261
+ },
262
+ "auth": {"username": None, "password": None},
263
+ }
264
+ with open("./config/%s.yaml" % args.tag, "w") as config_f:
265
+ rprint("Dumping as config ./config/%s.yaml" % args.tag)
266
+ rprint(config_dict)
267
+ yaml.dump(config_dict, stream=config_f)
268
+ rprint("Change parameter ./config/%s.yaml if necessary" % args.tag)
269
+ print("Reference Dumping Finished")
local/new_whisper_fine_tuning.py ADDED
@@ -0,0 +1,481 @@
1
+ fine_tuning_dir = "/fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400"
2
+
3
+ """
4
+ TODO:
5
+ + [ ] Data load
6
+ + [ ] Train / Test / Dev split
7
+ + [ ] Train / Test Phase
8
+ + [ ] Logging with Train / Dev / Test Loss
9
+ + [ ] Evaluation metrics
10
+ """
11
+ import pdb
12
+ import string
13
+ from pathlib import Path
14
+
15
+ import evaluate
16
+ import librosa
17
+ import torch
18
+ import torch.nn as nn
19
+ from datasets import Dataset, concatenate_datasets, load_dataset
20
+ from transformers import AutoProcessor
21
+
22
+ wer = evaluate.load("wer")
23
+ torch.cuda.set_device("cuda:0")
24
+
25
+ audio_dir = "./data/Patient_sil_trim_16k_normed_5_snr_40"
26
+ healthy_dir = "./data/Healthy"
27
+ Fary_PAL_30 = "./data/Fary_PAL_p326_20230110_30"
28
+ John_p326 = "./data/John_p326/output"
29
+ John_video = "./data/20230103_video"
30
+
31
+ ## train
32
+ p326_300_dir = "./data/John_p326_large"
33
+ P1tony_arthur = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Arthur_the_Rat/PAL"
34
+ P1tony_rainbow = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Rainbow_Passage/Laronix"
35
+
36
+ P1tony = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/CONVERSATIONAL/PAL"
37
+
38
+ P4Negel = 'data/4_negal_152_clean_all'
39
+
40
+ def dataclean(example):
41
+ if example["audio"]["sampling_rate"] != 16000:
42
+ resampled_audio = librosa.resample(
43
+ y=example["audio"]["array"],
44
+ orig_sr=example["audio"]["sampling_rate"],
45
+ target_sr=16000,
46
+ )
47
+
48
+ return {
49
+ "audio": {
50
+ "path": example["audio"]["path"],
51
+ "array": resampled_audio,
52
+ "sampling_rate": 16000,
53
+ },
54
+ "transcription": example["transcription"]
55
+ .upper()
56
+ .translate(str.maketrans("", "", string.punctuation)),
57
+ }
58
+ else:
59
+ return {
60
+ "transcription": example["transcription"]
61
+ .upper()
62
+ .translate(str.maketrans("", "", string.punctuation))
63
+ }
64
+
65
+
66
+
67
+ P1tony_dataset = load_dataset("audiofolder", data_dir=P1tony, split="train")
68
+ P1tony_dataset = P1tony_dataset.map(dataclean)
69
+
70
+ P1tony_scripted1 = load_dataset(
71
+ "audiofolder", data_dir=P1tony_rainbow, split="train"
72
+ )
73
+ P1tony_scripted2 = load_dataset(
74
+ "audiofolder", data_dir=P1tony_arthur, split="train"
75
+ )
76
+ P1tony_scripted1 = P1tony_scripted1.map(dataclean)
77
+ P1tony_scripted2 = P1tony_scripted2.map(dataclean)
78
+ P1tony_scripted = concatenate_datasets([P1tony_scripted1, P1tony_scripted2])
79
+
80
+ class ChangeSampleRate(nn.Module):
81
+ def __init__(self, input_rate: int, output_rate: int):
82
+ super().__init__()
83
+ self.output_rate = output_rate
84
+ self.input_rate = input_rate
85
+
86
+ def forward(self, wav: torch.tensor) -> torch.tensor:
87
+ # Only accepts 1-channel waveform input
88
+ wav = wav.view(wav.size(0), -1)
89
+ new_length = wav.size(-1) * self.output_rate // self.input_rate
90
+ indices = torch.arange(new_length) * (
91
+ self.input_rate / self.output_rate
92
+ )
93
+ round_down = wav[:, indices.long()]
94
+ round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
95
+ output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
96
+ 0
97
+ ) + round_up * indices.fmod(1.0).unsqueeze(0)
98
+ return output
99
+
100
+ # resample and clean text data
101
+ def dataclean(example):
102
+ # pdb.set_trace()
103
+ if example["audio"]["sampling_rate"] != 16000:
104
+ resampled_audio = librosa.resample(
105
+ y=example["audio"]["array"],
106
+ orig_sr=example["audio"]["sampling_rate"],
107
+ target_sr=16000,
108
+ )
109
+
110
+ return {
111
+ "audio": {
112
+ "path": example["audio"]["path"],
113
+ "array": resampled_audio,
114
+ "sampling_rate": 16000,
115
+ },
116
+ "transcription": example["transcription"]
117
+ .upper()
118
+ .translate(str.maketrans("", "", string.punctuation)),
119
+ }
120
+ else:
121
+ return {
122
+ "transcription": example["transcription"]
123
+ .upper()
124
+ .translate(str.maketrans("", "", string.punctuation))
125
+ }
126
+
127
+
128
+ # processor = AutoFeatureExtractor.from_pretrained(
129
+ # "facebook/wav2vec2-base-960h"
130
+ # )
131
+ processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
132
+
133
+ def prepare_dataset(batch):
134
+ audio = batch["audio"]
135
+ batch = processor(
136
+ audio["array"],
137
+ sampling_rate=audio["sampling_rate"],
138
+ text=batch["transcription"],
139
+ )
140
+ batch["input_length"] = len(batch["input_values"][0])
141
+ return batch
142
+
143
+ src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
144
+ src_dataset = src_dataset.map(dataclean)
145
+ p326_300_dataset = load_dataset(
146
+ "audiofolder", data_dir=p326_300_dir, split="train"
147
+ )
148
+ p326_300_dataset = p326_300_dataset.map(dataclean)
149
+
150
+ P4Negel_dataset = load_dataset("audiofolder", data_dir=P4Negel, split="train")
151
+ P4Negel_dataset = P4Negel_dataset.map(dataclean)
152
+
153
+ healthy_test_dataset = load_dataset(
154
+ "audiofolder", data_dir=healthy_dir, split="train"
155
+ )
156
+ healthy_test_dataset = healthy_test_dataset.map(dataclean)
157
+
158
+ Fary_PAL_test_dataset = load_dataset(
159
+ "audiofolder", data_dir=Fary_PAL_30, split="train"
160
+ )
161
+ Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)
162
+
163
+ John_p326_test_dataset = load_dataset(
164
+ "audiofolder", data_dir=John_p326, split="train"
165
+ )
166
+ John_p326_test_dataset = John_p326_test_dataset.map(dataclean)
167
+
168
+ John_video_test_dataset = load_dataset(
169
+ "audiofolder", data_dir=John_video, split="train"
170
+ )
171
+ John_video_test_dataset = John_video_test_dataset.map(dataclean)
172
+
173
+
174
+ def train_dev_test_split(
175
+ dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1
176
+ ):
177
+ """
178
+ input: dataset
179
+ dev_rate,
180
+ test_rate
181
+ seed
182
+ -------
183
+ Output:
184
+ dataset_dict{"train", "dev", "test"}
185
+ """
186
+ train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
187
+ test = train_dev_test["test"]
188
+ train_dev = train_dev_test["train"]
189
+
190
+ # pdb.set_trace()
191
+ if len(train_dev) <= int(len(dataset) * dev_rate):
192
+ train = Dataset.from_dict({"audio": [], "transcription": []})
193
+ dev = train_dev
194
+ else:
195
+ train_dev = train_dev.train_test_split(
196
+ test_size=int(len(dataset) * dev_rate), seed=seed
197
+ )
198
+ train = train_dev["train"]
199
+ dev = train_dev["test"]
200
+ return train, dev, test
201
+
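+ # Example: with dev_rate=test_rate=0.1, roughly 10% of the examples are held out for test,
+ # another int(len(dataset) * 0.1) for dev, and the remainder is returned as train.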
202
+ P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(
203
+ P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1
204
+ )
205
+ P1tony_train_ = concatenate_datasets([P1tony_train, P1tony_scripted])
206
+
207
+ # train_dev / test
208
+ ds = src_dataset.train_test_split(test_size=0.1, seed=1)
209
+
210
+ # dataset_libri = load_dataset(
211
+ # "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
212
+ # )
213
+
214
+ train_dev = ds["train"]
215
+ # train / dev
216
+ train_dev = train_dev.train_test_split(
217
+ test_size=int(len(src_dataset) * 0.1), seed=1
218
+ )
219
+
220
+ # Tony
221
+ Tony_train = P1tony_train_
222
+ Tony_dev = P1tony_dev
223
+ Tony_test = P1tony_test
224
+
225
+ # John
226
+ John_train, John_dev, John_test = train_dev_test_split(p326_300_dataset, dev_rate=0.1, test_rate=0.1)
227
+ # Negel
228
+ Negel_train, Negel_dev, Negel_test = train_dev_test_split(P4Negel_dataset, dev_rate=0.1, test_rate=0.1)
229
+
230
+ # train/dev/test
231
+ train = train_dev["train"]
232
+ test = ds["test"]
233
+ dev = train_dev["test"]
234
+
235
+ # combined
236
+ combine_train = concatenate_datasets([train, Tony_train, John_train, Negel_train])
237
+ conbine_dev = concatenate_datasets([dev, Tony_dev, John_dev, Negel_dev])
238
+ conbine_test = concatenate_datasets([test, Tony_test, John_test, Negel_test])
239
+
240
+ # encoded_train = combine_train.map(prepare_dataset, num_proc=4)
241
+ # encoded_dev = conbine_dev.map(prepare_dataset, num_proc=4)
242
+ # encoded_test = conbine_test.map(prepare_dataset, num_proc=4)
243
+
244
+ # # extra_test
245
+ # encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
246
+ # encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
247
+
248
+ # encoded_ori_test = test.map(prepare_dataset, num_proc=4)
249
+ # encoded_Tony_test = Tony_test.map(prepare_dataset, num_proc=4)
250
+ # encoded_John_test = John_test.map(prepare_dataset, num_proc=4)
251
+ # encoded_Negel_test = Negel_test.map(prepare_dataset, num_proc=4)
252
+
253
+ # encoded_train = train.map(prepare_dataset, num_proc=4)
254
+ # encoded_dev = dev.map(prepare_dataset, num_proc=4)
255
+ # p326_encoded_train = p326_300_dataset.map(prepare_dataset, num_proc=4)
256
+
257
+ # combine large p326 in to training set
258
+ # encoded_train = concatenate_datasets([encoded_train, p326_encoded_train])
259
+
260
+ # encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
261
+ # encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
262
+
263
+ # pdb.set_trace()
264
+ import numpy as np
265
+
266
+ WER = evaluate.load("wer")
267
+
268
+ ## Whisper decoding
269
+
270
+ from transformers import (Seq2SeqTrainer, Seq2SeqTrainingArguments,
271
+ WhisperFeatureExtractor,
272
+ WhisperForConditionalGeneration, WhisperModel,
273
+ WhisperProcessor, WhisperTokenizer)
274
+
275
+ processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
276
+ # model = WhisperForConditionalGeneration.from_pretrained(
277
+ # "./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400",
278
+ # use_auth_token=True,
279
+ # ).to("cuda:0")
280
+ model = WhisperForConditionalGeneration.from_pretrained(
281
+ "openai/whisper-medium",
282
+ ).to("cuda:0")
283
+ tokenizer = WhisperTokenizer.from_pretrained(
284
+ "openai/whisper-medium", language="English", task="transcribe"
285
+ )
286
+
287
+ from pathlib import Path
288
+
289
+ id = Path(fine_tuning_dir).stem
290
+ pdb.set_trace()
291
+ tokenizer.push_to_hub("KevinGeng/%s" % id)
292
+ # import pdb
293
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(
294
+ "openai/whisper-medium"
295
+ )
296
+
297
+ def whisper_prepare_dataset(batch):
298
+ # load and resample audio data from 48 to 16kHz
299
+ audio = batch["audio"]
300
+
301
+ # compute log-Mel input features from input audio array
302
+ batch["input_features"] = feature_extractor(
303
+ audio["array"], sampling_rate=audio["sampling_rate"]
304
+ ).input_features[0]
305
+
306
+ # encode target text to label ids
307
+ batch["labels"] = tokenizer(batch["transcription"]).input_ids
308
+ return batch
309
+
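+ # After this map, each example carries "input_features" (the log-Mel spectrogram from the
+ # Whisper feature extractor) and "labels" (token ids of the transcription).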
310
+ torch.cuda.empty_cache()
311
+
312
+
313
+ def my_map_to_pred(batch):
314
+ # pdb.set_trace()
315
+ audio = batch["audio"]
316
+ input_features = processor(
317
+ audio["array"],
318
+ sampling_rate=audio["sampling_rate"],
319
+ return_tensors="pt",
320
+ ).input_features
321
+ # batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
322
+ batch["reference"] = processor.tokenizer._normalize(batch["transcription"])
323
+
324
+ with torch.no_grad():
325
+ # predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
326
+ predicted_ids = model.generate(input_features.to("cuda"))[0]
327
+ transcription = processor.decode(predicted_ids)
328
+ batch["prediction"] = processor.tokenizer._normalize(transcription)
329
+ return batch
330
+
331
+
332
+ from dataclasses import dataclass
333
+ from typing import Any, Dict, List, Union
334
+
335
+ import torch
336
+
337
+
338
+ @dataclass
339
+ class DataCollatorSpeechSeq2SeqWithPadding:
340
+ processor: Any
341
+
342
+ def __call__(
343
+ self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
344
+ ) -> Dict[str, torch.Tensor]:
345
+ # split inputs and labels since they have to be of different lengths and need different padding methods
346
+ # first treat the audio inputs by simply returning torch tensors
347
+ input_features = [
348
+ {"input_features": feature["input_features"]}
349
+ for feature in features
350
+ ]
351
+ batch = self.processor.feature_extractor.pad(
352
+ input_features, return_tensors="pt"
353
+ )
354
+
355
+ # get the tokenized label sequences
356
+ label_features = [
357
+ {"input_ids": feature["labels"]} for feature in features
358
+ ]
359
+ # pad the labels to max length
360
+ labels_batch = self.processor.tokenizer.pad(
361
+ label_features, return_tensors="pt"
362
+ )
363
+
364
+ # replace padding with -100 to ignore loss correctly
365
+ labels = labels_batch["input_ids"].masked_fill(
366
+ labels_batch.attention_mask.ne(1), -100
367
+ )
368
+
369
+ # if bos token is appended in previous tokenization step,
370
+ # cut bos token here as it's appended later anyway
371
+ if (
372
+ (labels[:, 0] == self.processor.tokenizer.bos_token_id)
373
+ .all()
374
+ .cpu()
375
+ .item()
376
+ ):
377
+ labels = labels[:, 1:]
378
+
379
+ batch["labels"] = labels
380
+
381
+ return batch
382
+
383
+
384
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
385
+
386
+
387
+ def compute_metrics(pred):
388
+ pred_ids = pred.predictions
389
+ label_ids = pred.label_ids
390
+
391
+ # replace -100 with the pad_token_id
392
+ label_ids[label_ids == -100] = tokenizer.pad_token_id
393
+
394
+ # we do not want to group tokens when computing the metrics
395
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
396
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
397
+
398
+ wer = 100 * WER.compute(predictions=pred_str, references=label_str)
399
+
400
+ return {"wer": wer}
401
+
402
+ encoded_train = combine_train.map(whisper_prepare_dataset, num_proc=4)
403
+ encoded_dev = conbine_dev.map(whisper_prepare_dataset, num_proc=4)
404
+ encoded_test = conbine_test.map(whisper_prepare_dataset, num_proc=4)
405
+
406
+ # extra_test
407
+
408
+ encoded_ori_test = test.map(whisper_prepare_dataset, num_proc=4)
409
+ encoded_Tony_test = Tony_test.map(whisper_prepare_dataset, num_proc=4)
410
+ encoded_John_test = John_test.map(whisper_prepare_dataset, num_proc=4)
411
+ encoded_Negel_test = Negel_test.map(whisper_prepare_dataset, num_proc=4)
412
+
413
+ encoded_Fary = Fary_PAL_test_dataset.map(whisper_prepare_dataset, num_proc=4)
414
+ encoded_healthy = healthy_test_dataset.map(whisper_prepare_dataset, num_proc=4)
415
+
416
+ torch.cuda.empty_cache()
417
+
418
+ training_args = Seq2SeqTrainingArguments(
419
+ output_dir=fine_tuning_dir, # change to a repo name of your choice
420
+ per_device_train_batch_size=8,
421
+ gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
422
+ learning_rate=1e-5,
423
+ warmup_steps=50,
424
+ max_steps=1000,
425
+ gradient_checkpointing=True,
426
+ fp16=True,
427
+ evaluation_strategy="steps",
428
+ save_strategy="steps",
429
+ per_device_eval_batch_size=8,
430
+ predict_with_generate=True,
431
+ generation_max_length=512,
432
+ save_steps=20,
433
+ eval_steps=20,
434
+ logging_steps=10,
435
+ report_to=["tensorboard"],
436
+ load_best_model_at_end=True,
437
+ metric_for_best_model="wer",
438
+ greater_is_better=False,
439
+ save_total_limit=5,
440
+ push_to_hub=False,
441
+ )
442
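+ # With eval_steps=20 and the EarlyStoppingCallback(early_stopping_patience=10) used below,
+ # training halts once the dev WER fails to improve for 10 consecutive evaluations (200 steps).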
+ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
443
+
444
+ trainer = Seq2SeqTrainer(
445
+ args=training_args,
446
+ model=model,
447
+ train_dataset=Negel_train,
448
+ eval_dataset=Negel_dev,
449
+ data_collator=data_collator,
450
+ compute_metrics=compute_metrics,
451
+ tokenizer=processor.feature_extractor,
452
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
453
+ )
454
+ # callbacks=[EvalLoggingCallback()]
455
+ pdb.set_trace()
456
+
457
+ before_result_dict = {
458
+ "Ori_Test": trainer.evaluate(encoded_ori_test),
459
+ "Tony_Test": trainer.evaluate(encoded_Tony_test),
460
+ "John_Test": trainer.evaluate(encoded_John_test),
461
+ "Negel_Test": trainer.evaluate(encoded_Negel_test),
462
+ "Zeroshot_Fary_Test": trainer.evaluate(encoded_Fary),
463
+ "Healthy_Test": trainer.evaluate(encoded_healthy),
464
+ }
465
+
466
+ print(before_result_dict)
467
+ trainer.train()
468
+
469
+ pdb.set_trace()
470
+ result_dict = {
471
+ "Ori_Test": trainer.evaluate(encoded_ori_test),
472
+ "Tony_Test": trainer.evaluate(encoded_Tony_test),
473
+ "John_Test": trainer.evaluate(encoded_John_test),
474
+ "Negel_Test": trainer.evaluate(encoded_Negel_test),
475
+ "Zeroshot_Fary_Test": trainer.evaluate(encoded_Fary),
476
+ "Healthy_Test": trainer.evaluate(encoded_healthy),
477
+ }
478
+
479
+ pdb.set_trace()
480
+ # Evaluation
481
+ model.push_to_hub("KevinGeng/%s" % id)
local/new_whisper_fine_tuning_decode.py ADDED
@@ -0,0 +1,509 @@
1
+ fine_tuning_dir = "/home/kevingeng/Disk2/laronix/laronix_automos/fine_tuned/SSD/model/Org_Tony_John_Negel_Train_557_Dev_79_Test_81/checkpoint-160"
2
+ """
3
+ TODO:
4
+ + [ ] Data load
5
+ + [ ] Train / Test / Dev split
6
+ + [ ] Train / Test Phase
7
+ + [ ] Logging with Train / Dev / Test Loss
8
+ + [ ] Evaluation metrics
9
+ """
10
+ import pdb
11
+ import string
12
+ from pathlib import Path
13
+
14
+ import evaluate
15
+ import librosa
16
+ import torch
17
+ import torch.nn as nn
18
+ from datasets import Dataset, concatenate_datasets, load_dataset
19
+ from transformers import AutoProcessor
20
+
21
+ wer = evaluate.load("wer")
22
+ torch.cuda.set_device("cuda:0")
23
+
24
+ audio_dir = "./data/Patient_sil_trim_16k_normed_5_snr_40"
25
+ healthy_dir = "./data/Healthy"
26
+ Fary_PAL_30 = "./data/Fary_PAL_p326_20230110_30"
27
+ John_p326 = "./data/John_p326/output"
28
+ John_video = "./data/20230103_video"
29
+
30
+ ## train
31
+ p326_300_dir = "./data/John_p326_large"
32
+ P1tony_arthur = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Arthur_the_Rat/PAL"
33
+ P1tony_rainbow = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Rainbow_Passage/Laronix"
34
+
35
+ P1tony = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/CONVERSATIONAL/PAL"
36
+
37
+ P4Negel = 'data/4_negal_152_clean_all'
38
+
39
+ def dataclean(example):
40
+ if example["audio"]["sampling_rate"] != 16000:
41
+ resampled_audio = librosa.resample(
42
+ y=example["audio"]["array"],
43
+ orig_sr=example["audio"]["sampling_rate"],
44
+ target_sr=16000,
45
+ )
46
+
47
+ return {
48
+ "audio": {
49
+ "path": example["audio"]["path"],
50
+ "array": resampled_audio,
51
+ "sampling_rate": 16000,
52
+ },
53
+ "transcription": example["transcription"]
54
+ .upper()
55
+ .translate(str.maketrans("", "", string.punctuation)),
56
+ }
57
+ else:
58
+ return {
59
+ "transcription": example["transcription"]
60
+ .upper()
61
+ .translate(str.maketrans("", "", string.punctuation))
62
+ }
63
+
64
+ P1tony_dataset = load_dataset("audiofolder", data_dir=P1tony, split="train")
65
+ P1tony_dataset = P1tony_dataset.map(dataclean)
66
+
67
+ P1tony_scripted1 = load_dataset(
68
+ "audiofolder", data_dir=P1tony_rainbow, split="train"
69
+ )
70
+ P1tony_scripted2 = load_dataset(
71
+ "audiofolder", data_dir=P1tony_arthur, split="train"
72
+ )
73
+ P1tony_scripted1 = P1tony_scripted1.map(dataclean)
74
+ P1tony_scripted2 = P1tony_scripted2.map(dataclean)
75
+ P1tony_scripted = concatenate_datasets([P1tony_scripted1, P1tony_scripted2])
76
+
77
+ class ChangeSampleRate(nn.Module):
78
+ def __init__(self, input_rate: int, output_rate: int):
79
+ super().__init__()
80
+ self.output_rate = output_rate
81
+ self.input_rate = input_rate
82
+
83
+ def forward(self, wav: torch.tensor) -> torch.tensor:
84
+ # Only accepts 1-channel waveform input
85
+ wav = wav.view(wav.size(0), -1)
86
+ new_length = wav.size(-1) * self.output_rate // self.input_rate
87
+ indices = torch.arange(new_length) * (
88
+ self.input_rate / self.output_rate
89
+ )
90
+ round_down = wav[:, indices.long()]
91
+ round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
92
+ output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
93
+ 0
94
+ ) + round_up * indices.fmod(1.0).unsqueeze(0)
95
+ return output
96
+
97
+ # resample and clean text data
98
+ def dataclean(example):
99
+ # pdb.set_trace()
100
+ if example["audio"]["sampling_rate"] != 16000:
101
+ resampled_audio = librosa.resample(
102
+ y=example["audio"]["array"],
103
+ orig_sr=example["audio"]["sampling_rate"],
104
+ target_sr=16000,
105
+ )
106
+
107
+ return {
108
+ "audio": {
109
+ "path": example["audio"]["path"],
110
+ "array": resampled_audio,
111
+ "sampling_rate": 16000,
112
+ },
113
+ "transcription": example["transcription"]
114
+ .upper()
115
+ .translate(str.maketrans("", "", string.punctuation)),
116
+ }
117
+ else:
118
+ return {
119
+ "transcription": example["transcription"]
120
+ .upper()
121
+ .translate(str.maketrans("", "", string.punctuation))
122
+ }
123
+
124
+
125
+ # processor = AutoFeatureExtractor.from_pretrained(
126
+ # "facebook/wav2vec2-base-960h"
127
+ # )
128
+ processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
129
+
130
+ def prepare_dataset(batch):
131
+ audio = batch["audio"]
132
+ batch = processor(
133
+ audio["array"],
134
+ sampling_rate=audio["sampling_rate"],
135
+ text=batch["transcription"],
136
+ )
137
+ batch["input_length"] = len(batch["input_values"][0])
138
+ return batch
139
+
140
+ src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
141
+ src_dataset = src_dataset.map(dataclean)
142
+ p326_300_dataset = load_dataset(
143
+ "audiofolder", data_dir=p326_300_dir, split="train"
144
+ )
145
+ p326_300_dataset = p326_300_dataset.map(dataclean)
146
+
147
+ P4Negel_dataset = load_dataset("audiofolder", data_dir=P4Negel, split="train")
148
+ P4Negel_dataset = P4Negel_dataset.map(dataclean)
149
+
150
+ healthy_test_dataset = load_dataset(
151
+ "audiofolder", data_dir=healthy_dir, split="train"
152
+ )
153
+ healthy_test_dataset = healthy_test_dataset.map(dataclean)
154
+
155
+ Fary_PAL_test_dataset = load_dataset(
156
+ "audiofolder", data_dir=Fary_PAL_30, split="train"
157
+ )
158
+ Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)
159
+
160
+ John_p326_test_dataset = load_dataset(
161
+ "audiofolder", data_dir=John_p326, split="train"
162
+ )
163
+ John_p326_test_dataset = John_p326_test_dataset.map(dataclean)
164
+
165
+ John_video_test_dataset = load_dataset(
166
+ "audiofolder", data_dir=John_video, split="train"
167
+ )
168
+ John_video_test_dataset = John_video_test_dataset.map(dataclean)
169
+
170
+
171
+ def train_dev_test_split(
172
+ dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1
173
+ ):
174
+ """
175
+ input: dataset
176
+ dev_rate,
177
+ test_rate
178
+ seed
179
+ -------
180
+ Output:
181
+ dataset_dict{"train", "dev", "test"}
182
+ """
183
+ train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
184
+ test = train_dev_test["test"]
185
+ train_dev = train_dev_test["train"]
186
+
187
+ # pdb.set_trace()
188
+ if len(train_dev) <= int(len(dataset) * dev_rate):
189
+ train = Dataset.from_dict({"audio": [], "transcription": []})
190
+ dev = train_dev
191
+ else:
192
+ train_dev = train_dev.train_test_split(
193
+ test_size=int(len(dataset) * dev_rate), seed=seed
194
+ )
195
+ train = train_dev["train"]
196
+ dev = train_dev["test"]
197
+ return train, dev, test
198
+
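A quick sanity check of the split logic above on a hypothetical 100-row Dataset: the test portion is split off first by ratio, then dev is carved out of the remainder by absolute size, so dev_rate=test_rate=0.1 gives roughly 80/10/10. When the remainder is no larger than the dev budget (e.g. dev_rate=test_rate=0.5, as in the P1tony split just below), train comes back empty and the whole non-test half goes to dev.

    toy = Dataset.from_dict({"audio": list(range(100)), "transcription": ["x"] * 100})
    tr, dv, te = train_dev_test_split(toy, dev_rate=0.1, test_rate=0.1, seed=1)
    len(tr), len(dv), len(te)   # (80, 10, 10)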
199
+ P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(
200
+ P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1
201
+ )
202
+ P1tony_train_ = concatenate_datasets([P1tony_train, P1tony_scripted])
203
+
204
+ # train_dev / test
205
+ ds = src_dataset.train_test_split(test_size=0.1, seed=1)
206
+
207
+ # dataset_libri = load_dataset(
208
+ # "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
209
+ # )
210
+
211
+ train_dev = ds["train"]
212
+ # train / dev
213
+ train_dev = train_dev.train_test_split(
214
+ test_size=int(len(src_dataset) * 0.1), seed=1
215
+ )
216
+
217
+ # Tony
218
+ Tony_train = P1tony_train_
219
+ Tony_dev = P1tony_dev
220
+ Tony_test = P1tony_test
221
+
222
+ # John
223
+ John_train, John_dev, John_test = train_dev_test_split(p326_300_dataset, dev_rate=0.1, test_rate=0.1)
224
+ # Negel
225
+ Negel_train, Negel_dev, Negel_test = train_dev_test_split(P4Negel_dataset, dev_rate=0.1, test_rate=0.1)
226
+
227
+ # train/dev/test
228
+ train = train_dev["train"]
229
+ test = ds["test"]
230
+ dev = train_dev["test"]
231
+
232
+ # combined
233
+ combine_train = concatenate_datasets([train, Tony_train, John_train, Negel_train])
234
+ combine_dev = concatenate_datasets([dev, Tony_dev, John_dev, Negel_dev])
235
+ combine_test = concatenate_datasets([test, Tony_test, John_test, Negel_test])
236
+
237
+ # encoded_train = combine_train.map(prepare_dataset, num_proc=4)
238
+ # encoded_dev = combine_dev.map(prepare_dataset, num_proc=4)
239
+ # encoded_test = combine_test.map(prepare_dataset, num_proc=4)
240
+
241
+ # # extra_test
242
+ # encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
243
+ # encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
244
+
245
+ # encoded_ori_test = test.map(prepare_dataset, num_proc=4)
246
+ # encoded_Tony_test = Tony_test.map(prepare_dataset, num_proc=4)
247
+ # encoded_John_test = John_test.map(prepare_dataset, num_proc=4)
248
+ # encoded_Negel_test = Negel_test.map(prepare_dataset, num_proc=4)
249
+
250
+ # encoded_train = train.map(prepare_dataset, num_proc=4)
251
+ # encoded_dev = dev.map(prepare_dataset, num_proc=4)
252
+ # p326_encoded_train = p326_300_dataset.map(prepare_dataset, num_proc=4)
253
+
254
+ # combine large p326 in to training set
255
+ # encoded_train = concatenate_datasets([encoded_train, p326_encoded_train])
256
+
257
+ # encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
258
+ # encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
259
+
260
+ # pdb.set_trace()
261
+ import numpy as np
262
+
263
+ WER = evaluate.load("wer")
264
+
265
+ ## Whisper decoding
266
+
267
+ from transformers import (Seq2SeqTrainer, Seq2SeqTrainingArguments,
268
+ WhisperFeatureExtractor,
269
+ WhisperForConditionalGeneration, WhisperModel,
270
+ WhisperProcessor, WhisperTokenizer)
271
+
272
+ processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
273
+ model = WhisperForConditionalGeneration.from_pretrained(
274
+ fine_tuning_dir,
275
+ ).to("cuda:0")
276
+ # model = WhisperForConditionalGeneration.from_pretrained(
277
+ # "openai/whisper-medium",
278
+ # ).to("cuda:0")
279
+ tokenizer = WhisperTokenizer.from_pretrained(
280
+ "openai/whisper-medium", language="English", task="transcribe"
281
+ )
282
+
283
+ from pathlib import Path
284
+
285
+ id = Path(fine_tuning_dir).stem
286
+ pdb.set_trace()
287
+ # tokenizer.push_to_hub("KevinGeng/%s" % id)
288
+ # import pdb
289
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(
290
+ "openai/whisper-medium"
291
+ )
292
+
293
+ def whisper_prepare_dataset(batch):
294
+ # load and resample audio data from 48 to 16kHz
295
+ audio = batch["audio"]
296
+
297
+ # compute log-Mel input features from input audio array
298
+ batch["input_features"] = feature_extractor(
299
+ audio["array"], sampling_rate=audio["sampling_rate"]
300
+ ).input_features[0]
301
+
302
+ # encode target text to label ids
303
+ batch["labels"] = tokenizer(batch["transcription"]).input_ids
304
+ return batch
305
+
306
+ torch.cuda.empty_cache()
307
+
308
+
309
+ def my_map_to_pred(batch):
310
+ # pdb.set_trace()
311
+ audio = batch["audio"]
312
+ input_features = processor(
313
+ audio["array"],
314
+ sampling_rate=audio["sampling_rate"],
315
+ return_tensors="pt",
316
+ ).input_features
317
+ # batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
318
+ batch["reference"] = processor.tokenizer._normalize(batch["transcription"])
319
+
320
+ with torch.no_grad():
321
+ # predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
322
+ predicted_ids = model.generate(input_features.to("cuda"))[0]
323
+ transcription = processor.decode(predicted_ids)  # decode via the processor; the model object has no decode()
324
+ batch["prediction"] = processor.tokenizer._normalize(transcription)
325
+ return batch
326
+
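my_map_to_pred is not called below (evaluation goes through Seq2SeqTrainer), but the manual decoding path it implements would be used roughly like this (the dataset choice is illustrative):

    results = healthy_test_dataset.map(my_map_to_pred)
    WER.compute(references=results["reference"], predictions=results["prediction"])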
327
+
328
+ from dataclasses import dataclass
329
+ from typing import Any, Dict, List, Union
330
+
331
+ import torch
332
+
333
+
334
+ @dataclass
335
+ class DataCollatorSpeechSeq2SeqWithPadding:
336
+ processor: Any
337
+
338
+ def __call__(
339
+ self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
340
+ ) -> Dict[str, torch.Tensor]:
341
+ # split inputs and labels since they have to be of different lengths and need different padding methods
342
+ # first treat the audio inputs by simply returning torch tensors
343
+ input_features = [
344
+ {"input_features": feature["input_features"]}
345
+ for feature in features
346
+ ]
347
+ batch = self.processor.feature_extractor.pad(
348
+ input_features, return_tensors="pt"
349
+ )
350
+
351
+ # get the tokenized label sequences
352
+ label_features = [
353
+ {"input_ids": feature["labels"]} for feature in features
354
+ ]
355
+ # pad the labels to max length
356
+ labels_batch = self.processor.tokenizer.pad(
357
+ label_features, return_tensors="pt"
358
+ )
359
+
360
+ # replace padding with -100 to ignore loss correctly
361
+ labels = labels_batch["input_ids"].masked_fill(
362
+ labels_batch.attention_mask.ne(1), -100
363
+ )
364
+
365
+ # if bos token is appended in previous tokenization step,
366
+ # cut bos token here as it's appended later anyway
367
+ if (
368
+ (labels[:, 0] == self.processor.tokenizer.bos_token_id)
369
+ .all()
370
+ .cpu()
371
+ .item()
372
+ ):
373
+ labels = labels[:, 1:]
374
+
375
+ batch["labels"] = labels
376
+
377
+ return batch
378
+
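The collator is only ever invoked by the Trainer, but its contract can be sketched as follows once encoded_train (built further down) exists; the feature shape assumes whisper-medium's 80-mel, 30-second log-Mel spectrograms:

    batch = data_collator([encoded_train[0], encoded_train[1]])
    # batch["input_features"].shape -> torch.Size([2, 80, 3000])
    # batch["labels"]               -> right-padded label ids with padding positions set to -100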
379
+
380
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
381
+
382
+
383
+ def compute_metrics(pred):
384
+ pred_ids = pred.predictions
385
+ label_ids = pred.label_ids
386
+
387
+ # replace -100 with the pad_token_id
388
+ label_ids[label_ids == -100] = tokenizer.pad_token_id
389
+
390
+ # we do not want to group tokens when computing the metrics
391
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
392
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
393
+
394
+ wer = 100 * WER.compute(predictions=pred_str, references=label_str)
395
+
396
+ return {"wer": wer}
397
+
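The metric wiring can be sanity-checked in isolation (toy strings, not project data); note that compute_metrics reports WER as a percentage while the raw metric returns a fraction:

    WER.compute(predictions=["hello world"], references=["hello there world"])   # 1 deletion / 3 reference words = 0.333...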
398
+ encoded_train = combine_train.map(whisper_prepare_dataset, num_proc=4)
399
+ encoded_dev = combine_dev.map(whisper_prepare_dataset, num_proc=4)
400
+ encoded_test = combine_test.map(whisper_prepare_dataset, num_proc=4)
401
+
402
+ # extra_test
403
+
404
+ encoded_ori_test = test.map(whisper_prepare_dataset, num_proc=4) # 7 / 16
405
+
406
+ encoded_Tony_test = Tony_test.map(whisper_prepare_dataset, num_proc=4) # 0 / 19
407
+
408
+ encoded_John_test = John_test.map(whisper_prepare_dataset, num_proc=4) # 0 / 30
409
+ # [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]
410
+ encoded_Negel_test = Negel_test.map(whisper_prepare_dataset, num_proc=4) # 12 / 16
411
+ # [False, True, True, True, True, True, True, False, True, True, True, True, True, True, False, False]
412
+ encoded_Fary = Fary_PAL_test_dataset.map(whisper_prepare_dataset, num_proc=4) # 12 / 30
413
+ # [True, True, True, True, True, False, False, False, True, False, False, True, False, False, True, False, True, True, False, True, True, False, False, False, False, False, False, False, False, False]
414
+ encoded_healthy = healthy_test_dataset.map(whisper_prepare_dataset, num_proc=4) # 5 / 160
415
+ # [False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, True, True, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False]
416
+
417
+ # Ensure content variability: drop test utterances whose transcriptions also appear in train/dev
418
+ train_tuple = tuple(encoded_train['transcription'])
419
+ dev_tuple = tuple(encoded_dev['transcription'])
420
+ train_dev_tuple = (train_tuple + dev_tuple)
421
+
422
+ pdb.set_trace()
423
+ new_encoded_test = encoded_test.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_test['transcription']]))[0])
424
+ new_encoded_ori_test = encoded_ori_test.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_ori_test['transcription']]))[0])
425
+ new_encoded_Tony_test = encoded_Tony_test.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_Tony_test['transcription']]))[0])
426
+ new_encoded_John_test = encoded_John_test.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_John_test['transcription']]))[0])
427
+ new_encoded_Negel_test = encoded_Negel_test.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_Negel_test['transcription']]))[0])
428
+ new_encoded_Fary = encoded_Fary.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_Fary['transcription']]))[0])
429
+ new_encoded_healthy = encoded_healthy.select(np.where(np.array([False if x in train_dev_tuple else True for x in encoded_healthy['transcription']]))[0])
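The seven select(...) calls above all apply the same filter: keep only utterances whose transcription string never occurs in the train/dev material, so content overlap cannot flatter the WER numbers. An equivalent, more readable form of one of them (illustrative only):

    mask = [t not in train_dev_tuple for t in encoded_test["transcription"]]
    new_encoded_test = encoded_test.select(np.flatnonzero(mask))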
430
+ pdb.set_trace()
431
+ torch.cuda.empty_cache()
432
+
433
+ training_args = Seq2SeqTrainingArguments(
434
+ output_dir=fine_tuning_dir, # change to a repo name of your choice
435
+ per_device_train_batch_size=8,
436
+ gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
437
+ learning_rate=1e-5,
438
+ warmup_steps=50,
439
+ max_steps=1000,
440
+ gradient_checkpointing=True,
441
+ fp16=True,
442
+ evaluation_strategy="steps",
443
+ save_strategy="steps",
444
+ per_device_eval_batch_size=8,
445
+ predict_with_generate=True,
446
+ generation_max_length=512,
447
+ save_steps=20,
448
+ eval_steps=20,
449
+ logging_steps=10,
450
+ report_to=["tensorboard"],
451
+ load_best_model_at_end=True,
452
+ metric_for_best_model="wer",
453
+ greater_is_better=False,
454
+ save_total_limit=10,
455
+ push_to_hub=False,
456
+ )
457
+ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
458
+
459
+ trainer = Seq2SeqTrainer(
460
+ args=training_args,
461
+ model=model,
462
+ train_dataset=encoded_train,
463
+ eval_dataset=encoded_dev,
464
+ data_collator=data_collator,
465
+ compute_metrics=compute_metrics,
466
+ tokenizer=processor.feature_extractor,
467
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
468
+ )
469
+ # callbacks=[EvalLoggingCallback()]
470
+ pdb.set_trace()
471
+
472
+ result_dict = {
473
+ "Ori_Test": trainer.evaluate(encoded_ori_test),
474
+ "Tony_Test": trainer.evaluate(encoded_Tony_test),
475
+ "John_Test": trainer.evaluate(encoded_John_test),
476
+ "Negel_Test": trainer.evaluate(encoded_Negel_test),
477
+ "Zeroshot_Fary_Test": trainer.evaluate(encoded_Fary),
478
+ "Healthy_Test": trainer.evaluate(encoded_healthy),
479
+ }
480
+
481
+ # print(result_dict)
482
+
483
+ pdb.set_trace()
484
+ trainer.evaluate(encoded_test)
485
+ trainer.evaluate(new_encoded_test)
486
+ new_result_dict = {
487
+ "Ori_Test": trainer.evaluate(new_encoded_ori_test), # 'eval_wer': 12.345679012345679,
488
+ "Tony_Test": trainer.evaluate(new_encoded_Tony_test), # 'eval_wer': 25.0,
489
+ "John_Test": trainer.evaluate(new_encoded_John_test),
490
+ "Negel_Test": trainer.evaluate(new_encoded_Negel_test), # 2.08
491
+ "Zeroshot_Fary_Test": trainer.evaluate(new_encoded_Fary), ## 11.49
492
+ "Healthy_Test": trainer.evaluate(new_encoded_healthy),
493
+ }
494
+
495
+ print(new_result_dict)
496
+
497
+ # pdb.set_trace()
498
+ # result_dict = {
499
+ # "Ori_Test": trainer.evaluate(encoded_ori_test),
500
+ # "Tony_Test": trainer.evaluate(encoded_Tony_test),
501
+ # "John_Test": trainer.evaluate(encoded_John_test),
502
+ # "Negel_Test": trainer.evaluate(encoded_Negel_test),
503
+ # "Zeroshot_Fary_Test": trainer.evaluate(encoded_Fary),
504
+ # "Healthy_Test": trainer.evaluate(encoded_healthy),
505
+ # }
506
+
507
+ # pdb.set_trace()
508
+ # # Evaluation
509
+ # model.push_to_hub("KevinGeng/%s" % id)
local/post_processing.py ADDED
@@ -0,0 +1,56 @@
1
+ '''
2
+ # Post processing module for data recording
3
+ # Author: Kevin Geng @Laronix, Sep. 2022
4
+
5
+ # Load log.csv, generate standard wav files with the selected sample rate, and calculate statistical features
6
+ '''
7
+
8
+ from random import sample
9
+ import librosa
10
+ import soundfile as sf
11
+ import numpy as np
12
+ import pdb
13
+ from pathlib import Path
14
+ import sys
15
+ import pandas as pd
16
+ indir = Path(sys.argv[1])
17
+ assert indir.exists() == True
18
+ wavs = Path(indir/Path("Audio_to_Evaluate")).glob("**/*.wav")
19
+ log = Path(indir/Path("log.csv"))
20
+
21
+ # x = np.loadtxt(log, dtype=str, delimiter=",")
22
+ x = pd.read_csv(log, header=0)
23
+
24
+ # y, sr = librosa.load("/home/kevingeng/laronix_automos/Julianna/Audio_to_evaluate/tmp0kgcdpi2.wav", sr=48000)
25
+ outdir = indir/Path("output")
26
+ # pdb.set_trace()
27
+ # outdir_clean = indir/Path("output_clean")
28
+ Path.mkdir(outdir, exist_ok=True)
29
+ # Path.mkdir(outdir_clean, exist_ok=True)
30
+ ## Capitalize Evaluate
31
+ # for i, j in zip(x["Audio_to_Evaluate"], x["Reference_ID"]):
32
+ # y, sr = librosa.load(i, sr=48000)
33
+ # # kevin 1017 John's trial with original data.
34
+ # y_ = librosa.util.normalize(y, norm=5)
35
+ # y_cut, index = librosa.effects.trim(y_, top_db=30)
36
+ # # normalized and cut
37
+ # # pdb.set_trace()
38
+ # # sf.write(outdir/Path(str(indir)+"_"+ j +".wav"), y_cut, samplerate=sr)
39
+ # sf.write(outdir/Path(Path(indir).stem+"_"+ j +".wav"), y_cut, samplerate=sr)
40
+
41
+ def process_audio(file_path, ref_id, sr=48000, norm=5, top_db=30):
42
+ y, _ = librosa.load(file_path, sr=sr)
43
+ y_norm = librosa.util.normalize(y, norm=norm)
44
+ y_cut, _ = librosa.effects.trim(y_norm, top_db=top_db)
45
+ return y_cut
46
+
47
+ def save_audio(y_cut, ref_id, outdir, indir, sr=48000):
48
+ out_path = outdir / f"{Path(indir).stem}_{ref_id}.wav"
49
+ sf.write(out_path, y_cut, samplerate=sr)
50
+
51
+ def main(audio_files, ref_ids, outdir, indir):
52
+ for file_path, ref_id in zip(audio_files, ref_ids):
53
+ y_cut = process_audio(file_path, ref_id)
54
+ save_audio(y_cut, ref_id, outdir, indir)
55
+
56
+ main(x["Audio_to_Evaluate"], x["Reference_ID"], outdir, indir)
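A hedged usage note for this script (the directory name is illustrative):

    # python local/post_processing.py data/Julianna_session01
    # expects <dir>/log.csv, whose Audio_to_Evaluate column lists the recordings to process;
    # writes trimmed (top_db=30), normalised (norm=5) 48 kHz copies named
    # <dir-stem>_<Reference_ID>.wav to <dir>/output/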
local/wer_plot_report.py ADDED
@@ -0,0 +1,45 @@
1
+ from pathlib import Path
2
+ import pandas as pd
3
+
4
+ import numpy as np
5
+
6
+ import matplotlib.pyplot as plt
7
+ import sys
8
+ import pdb
9
+
10
+ threshold = 0.3
11
+ if __name__ == "__main__":
12
+ wer_csv = sys.argv[1]
13
+ df = pd.read_csv(wer_csv)
14
+ fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(25, 15))
15
+
16
+ # Hist for distribution
17
+ ax[0].set_xlabel("Word Error Rate")
18
+ ax[0].set_ylabel("Counts")
19
+ ax[0].set_xlim(left=0.0, right=df['wer'].max())
20
+ ax[0].hist(df['wer'], bins=50)
21
+ ax[0].axvline(x=threshold, color="r")
22
+ # plt.savefig("hist.png")
23
+
24
+ # Line curve for each sentences
25
+ colors = ['green' if x < threshold else 'red' for x in df['wer']]
26
+
27
+ new_ids = [str(x).split('.')[0] for x in df['id']]
28
+ ax[1].set_xlabel("IDs")
29
+ ax[1].set_ylabel("Word Error Rate")
30
+ ax[1].scatter(new_ids, df['wer'], c=colors, marker='o')
31
+ ax[1].vlines(new_ids, ymin=0, ymax=df['wer'], colors='grey', linestyle='dotted', label='Vertical Lines')
32
+ ax[1].axhline(y=threshold, color='r')  # axhline's xmin/xmax are axis fractions, so keep the default full width
33
+
34
+ # ax[0].axhline(y=threshold, color="black")
35
+
36
+ # for i, v in enumerate(df['wer']):
37
+ # plt.text(str(df['id'][i]).split('.')[0], -2, str(df['id'][i]), ha='center', fontsize=3)
38
+
39
+ ax[1].set_xticklabels(new_ids, rotation=90, fontsize=10)
40
+ ax[1].tick_params(axis='x', width=20)
41
+ # ax[1].set_xlim(10, len(df['id']) + 10)
42
+ plt.tight_layout()
43
+ pdb.set_trace()
44
+ # fig.savefig("%s/%s.png"%(Path(sys.argv[1]).parent, sys.argv[1].split('/')[-1]), format='png')
45
+ fig.savefig("%s.png"%(sys.argv[1]), format='png')
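A hedged usage note (the CSV path is illustrative): the input needs 'id' and 'wer' columns, the red/green colouring uses the hard-coded 0.3 threshold, and the figure is saved next to the CSV as <csv>.png; note the pdb.set_trace() left in just before saving, so the file is only written after continuing from the debugger.

    # python local/wer_plot_report.py exp/decode/wer.csv   -> exp/decode/wer.csv.png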
local/whisper_fine_tuning_large_with_negel.py ADDED
@@ -0,0 +1,442 @@
1
+ fine_tuning_dir = "/home/kevingeng/Disk2/laronix/laronix_automos/fine_tuned/SSD/model/Michael_52_with_Large_AVA_script_conv_train_conv_dev/checkpoint-60"
2
+ """
3
+ TODO:
4
+ + [x] Load Configuration
5
+ + [ ] Multi ASR Engine
6
+ + [ ] Batch / Real Time support
7
+ """
8
+ from pathlib import Path
9
+ from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
10
+ from datasets import load_dataset, concatenate_datasets
11
+ from datasets import Dataset, Audio
12
+ import pdb
13
+ import string
14
+ import librosa
15
+ # local import
16
+ import sys
17
+
18
+ sys.path.append("src")
19
+ import torch
20
+ torch.cuda.set_device("cuda:0")
21
+ # token_model = AutoModelForCTC.from_pretrained(
22
+ # "facebook/wav2vec2-base-960h"
23
+ # )
24
+
25
+ # audio_dir= "/Users/kevingeng/Laronix/laronix_automos/data/Patient_sil_trim_16k_normed_5_snr_40/"
26
+ audio_dir ="./data/Patient_sil_trim_16k_normed_5_snr_40"
27
+ healthy_dir="./data/Healthy"
28
+ Fary_PAL_30="./data/Fary_PAL_p326_20230110_30"
29
+ John_p326 = "./data/John_p326/output"
30
+ John_video = "./data/20230103_video"
31
+ p326_300_dir ="./data/John_p326_large"
32
+ negel_152 = "./data/4_negal_152_clean_all"
33
+
34
+ michael3_52 = "data/3_michael_20230619_52"
35
+
36
+ patient_T = "data/Patient_T/Patient_T"
37
+ patient_L = "data/Patient_L/Patient_L"
38
+ P1tony = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/CONVERSATIONAL/PAL"
39
+ P1tony_arthur = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Arthur_the_Rat/PAL"
40
+ P1tony_rainbow = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Rainbow_Passage/Laronix"
41
+
42
+ def dataclean(example):
43
+ # pdb.set_trace()
44
+ if example['audio']['sampling_rate'] != 16000:
45
+ resampled_audio = librosa.resample(y=example['audio']['array'],
46
+ orig_sr= example['audio']['sampling_rate'],
47
+ target_sr=16000)
48
+ # torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
49
+ # resampled_audio = resampler(example['audio']['array'])
50
+
51
+ return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
52
+ "transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
53
+ else:
54
+ return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
55
+
56
+ # patient_L_test_dataset = load_dataset("audiofolder", data_dir=patient_L, split="train")
57
+ # patient_L_test_dataset = patient_L_test_dataset.map(dataclean)
58
+
59
+ # patient_T_test_dataset = load_dataset("audiofolder", data_dir=patient_T, split="train")
60
+ # patient_T_test_dataset = patient_T_test_dataset.map(dataclean)
61
+
62
+ P1tony_dataset = load_dataset("audiofolder", data_dir=P1tony, split="train")
63
+ P1tony_dataset = P1tony_dataset.map(dataclean)
64
+
65
+ P3Micheal_dataset_52 = load_dataset("audiofolder", data_dir=michael3_52, split="train")
66
+ P3Micheal_dataset_52 = P3Micheal_dataset_52.map(dataclean)
67
+
68
+ # negel_152_dataset = load_dataset("audiofolder", data_dir=negel_152, split="train")
69
+ # negel_152_dataset = negel_152_dataset.map(dataclean)
70
+
71
+
72
+ # pdb.set_trace()
73
+ # P1tony_scripted1 = load_dataset("audiofolder", data_dir=P1tony_rainbow, split="train")
74
+ # P1tony_scripted2 = load_dataset("audiofolder", data_dir=P1tony_arthur, split="train")
75
+ # P1tony_scripted1 = P1tony_scripted1.map(dataclean)
76
+ # P1tony_scripted2 = P1tony_scripted2.map(dataclean)
77
+ # P1tony_scripted = concatenate_datasets([P1tony_scripted1, P1tony_scripted2])
78
+
79
+ # audio_dir ="/home/kevingeng/laronix/laronix_automos/data/Healthy"
80
+ # tgt_audio_dir= "/Users/kevingeng/Laronix/Dataset/Pneumatic/automos"
81
+
82
+ # Get Transcription, WER and PPM
83
+ """
84
+ TODO:
85
+ [DONE]: Automatic generating Config
86
+ """
87
+
88
+ import yaml
89
+ import argparse
90
+ import sys
91
+ from pathlib import Path
92
+
93
+ sys.path.append("./src")
94
+ import lightning_module
95
+ from UV import plot_UV, get_speech_interval
96
+ from transformers import pipeline
97
+ from rich.progress import track
98
+ from rich import print as rprint
99
+ import numpy as np
100
+ import jiwer
101
+ import pdb
102
+ import torch.nn as nn
103
+ import torch
104
+ import torchaudio
105
+ import gradio as gr
106
+ from sys import flags
107
+ from random import sample
108
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
109
+
110
+ import evaluate
111
+
112
+ wer = evaluate.load("wer")
113
+
114
+ # root_path = Path(__file__).parents[1]
115
+
116
+ class ChangeSampleRate(nn.Module):
117
+ def __init__(self, input_rate: int, output_rate: int):
118
+ super().__init__()
119
+ self.output_rate = output_rate
120
+ self.input_rate = input_rate
121
+
122
+ def forward(self, wav: torch.Tensor) -> torch.Tensor:
123
+ # Only accepts 1-channel waveform input
124
+ wav = wav.view(wav.size(0), -1)
125
+ new_length = wav.size(-1) * self.output_rate // self.input_rate
126
+ indices = torch.arange(new_length) * (
127
+ self.input_rate / self.output_rate
128
+ )
129
+ round_down = wav[:, indices.long()]
130
+ round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
131
+ output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
132
+ 0
133
+ ) + round_up * indices.fmod(1.0).unsqueeze(0)
134
+ return output
135
+
136
+ # resample and clean text data
137
+ def dataclean(example):
138
+ # pdb.set_trace()
139
+ if example['audio']['sampling_rate'] != 16000:
140
+ resampled_audio = librosa.resample(y=example['audio']['array'],
141
+ orig_sr= example['audio']['sampling_rate'],
142
+ target_sr=16000)
143
+ # torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
144
+ # resampled_audio = resampler(example['audio']['array'])
145
+
146
+ return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
147
+ "transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
148
+ else:
149
+ return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
150
+
151
+ # processor = AutoFeatureExtractor.from_pretrained(
152
+ # "facebook/wav2vec2-base-960h"
153
+ # )
154
+ processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
155
+
156
+ def prepare_dataset(batch):
157
+ audio = batch["audio"]
158
+ batch = processor(audio["array"], sampling_rate = audio["sampling_rate"], text=batch['transcription'])
159
+ batch["input_length"] = len(batch["input_values"][0])
160
+ return batch
161
+
162
+ src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
163
+ src_dataset = src_dataset.map(dataclean)
164
+ p326_300_dataset = load_dataset("audiofolder", data_dir=p326_300_dir, split="train")
165
+ p326_300_dataset = p326_300_dataset.map(dataclean)
166
+
167
+ # healthy_test_dataset = load_dataset("audiofolder", data_dir=healthy_dir, split='train')
168
+ # healthy_test_dataset = healthy_test_dataset.map(dataclean)
169
+
170
+ # Fary_PAL_test_dataset = load_dataset("audiofolder", data_dir=Fary_PAL_30, split='train')
171
+ # Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)
172
+
173
+ # John_p326_test_dataset = load_dataset("audiofolder", data_dir=John_p326, split='train')
174
+ # John_p326_test_dataset = John_p326_test_dataset.map(dataclean)
175
+
176
+ # John_video_test_dataset = load_dataset("audiofolder", data_dir=John_video, split='train')
177
+ # John_video_test_dataset = John_video_test_dataset.map(dataclean)
178
+
179
+ # pdb.set_trace()
180
+
181
+ def train_dev_test_split(dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1):
182
+ """
183
+ input: dataset
184
+ dev_rate,
185
+ test_rate
186
+ seed
187
+ -------
188
+ Output:
189
+ (train, dev, test): a tuple of three Datasets
190
+ """
191
+ train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
192
+ test = train_dev_test["test"]
193
+ train_dev = train_dev_test['train']
194
+
195
+ # pdb.set_trace()
196
+ if len(train_dev) <= int(len(dataset)*dev_rate):
197
+ train = Dataset.from_dict({"audio": [], "transcription": []})
198
+ dev = train_dev
199
+ else:
200
+ train_dev = train_dev.train_test_split(test_size=int(len(dataset)*dev_rate), seed=seed)
201
+ train = train_dev['train']
202
+ dev = train_dev['test']
203
+ return train, dev, test
204
+
205
+ # pdb.set_trace()
206
+ # P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1)
207
+ # P1tony_train_ = concatenate_datasets([P1tony_train,P1tony_scripted])
208
+ # pdb.set_trace()
209
+
210
+ Michael_52_train, Michael_52_dev, Michael_52_test = train_dev_test_split(P3Micheal_dataset_52, dev_rate=0.15, test_rate=0.15, seed=1)
211
+
212
+ # train_dev / test
213
+ ds = src_dataset.train_test_split(test_size=0.1, seed=1)
214
+
215
+ # dataset_libri = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
216
+
217
+ train_dev = ds['train']
218
+ # train / dev
219
+ train_dev = train_dev.train_test_split(test_size=int(len(src_dataset)*0.1), seed=1)
220
+ # train/dev/test
221
+ train = train_dev['train']
222
+ test = ds['test']
223
+ dev = train_dev['test']
224
+
225
+ encoded_train = train.map(prepare_dataset, num_proc=4)
226
+ encoded_dev = dev.map(prepare_dataset, num_proc=4)
227
+ encoded_test = test.map(prepare_dataset, num_proc=4)
228
+ p326_encoded_train = p326_300_dataset.map(prepare_dataset, num_proc=4)
229
+
230
+ # combine large p326 in to training set
231
+ encoded_train = concatenate_datasets([encoded_train, p326_encoded_train])
232
+
233
+ # encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
234
+ # encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
235
+ # encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
236
+ # encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
237
+
238
+ # encoded_P1tony_train = P1tony_train.map(prepare_dataset, num_proc=4)
239
+ # encoded_P1tony_dev = P1tony_dev.map(prepare_dataset, num_proc=4)
240
+ # encoded_P1tony_test = P1tony_test.map(prepare_dataset, num_proc=4)
241
+
242
+ # pdb.set_trace()
243
+ import numpy as np
244
+
245
+ WER = evaluate.load("wer")
246
+
247
+ ## Whisper decoding
248
+
249
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperFeatureExtractor, Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperModel
250
+ processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
251
+ # model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to("cuda:0")
252
+ model = WhisperForConditionalGeneration.from_pretrained("./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400", use_auth_token=True).to("cuda:0")
253
+ tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-medium", language="English", task="transcribe")
254
+
255
+ from pathlib import Path
256
+ id = Path(fine_tuning_dir).stem
257
+ pdb.set_trace()
258
+ tokenizer.push_to_hub("KevinGeng/%s"%id)
259
+ # import pdb
260
+ feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-medium")
261
+
262
+ def whisper_prepare_dataset(batch):
263
+ # load and resample audio data from 48 to 16kHz
264
+ audio = batch["audio"]
265
+
266
+ # compute log-Mel input features from input audio array
267
+ batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
268
+
269
+ # encode target text to label ids
270
+ batch["labels"] = tokenizer(batch["transcription"]).input_ids
271
+ return batch
272
+
273
+ torch.cuda.empty_cache()
274
+
275
+ def my_map_to_pred(batch):
276
+ # pdb.set_trace()
277
+ audio = batch["audio"]
278
+ input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
279
+ # batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
280
+ batch["reference"] = processor.tokenizer._normalize(batch['transcription'])
281
+
282
+ with torch.no_grad():
283
+ # predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
284
+ predicted_ids = model.generate(input_features.to("cuda"))[0]
285
+ transcription = processor.decode(predicted_ids)  # decode via the processor; the model object has no decode()
286
+ batch["prediction"] = processor.tokenizer._normalize(transcription)
287
+ return batch
288
+
289
+ import torch
290
+
291
+ from dataclasses import dataclass
292
+ from typing import Any, Dict, List, Union
293
+
294
+ @dataclass
295
+ class DataCollatorSpeechSeq2SeqWithPadding:
296
+ processor: Any
297
+
298
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
299
+ # split inputs and labels since they have to be of different lengths and need different padding methods
300
+ # first treat the audio inputs by simply returning torch tensors
301
+ input_features = [{"input_features": feature["input_features"]} for feature in features]
302
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
303
+
304
+ # get the tokenized label sequences
305
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
306
+ # pad the labels to max length
307
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
308
+
309
+ # replace padding with -100 to ignore loss correctly
310
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
311
+
312
+ # if bos token is appended in previous tokenization step,
313
+ # cut bos token here as it's appended later anyway
314
+ if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
315
+ labels = labels[:, 1:]
316
+
317
+ batch["labels"] = labels
318
+
319
+ return batch
320
+
321
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
322
+
323
+ def compute_metrics(pred):
324
+ pred_ids = pred.predictions
325
+ label_ids = pred.label_ids
326
+
327
+ # replace -100 with the pad_token_id
328
+ label_ids[label_ids == -100] = tokenizer.pad_token_id
329
+
330
+ # we do not want to group tokens when computing the metrics
331
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
332
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
333
+
334
+ wer = 100 * WER.compute(predictions=pred_str, references=label_str)
335
+
336
+ return {"wer": wer}
337
+
338
+ # whisper_train = train.map(whisper_prepare_dataset, num_proc=4)
339
+ # pdb.set_trace()
340
+ whisper_train_large = encoded_train.map(whisper_prepare_dataset, num_proc=4)
341
+ whisper_dev = dev.map(whisper_prepare_dataset, num_proc=4)
342
+ whisper_test = test.map(whisper_prepare_dataset, num_proc=4)
343
+
344
+ encoded_Michael_52_train = Michael_52_train.map(whisper_prepare_dataset, num_proc=4)
345
+ encoded_Michael_52_dev = Michael_52_dev.map(whisper_prepare_dataset, num_proc=4)
346
+ encoded_Michael_52_test = Michael_52_test.map(whisper_prepare_dataset, num_proc=4)
347
+ # pdb.set_trace()
348
+ # # Add scripted tony
349
+ # encoded_P1tony_train = P1tony_train_.map(whisper_prepare_dataset, num_proc=4)
350
+ # encoded_P1tony_dev = P1tony_dev.map(whisper_prepare_dataset, num_proc=4)
351
+ # encoded_P1tony_test = P1tony_test.map(whisper_prepare_dataset, num_proc=4)
352
+
353
+ # encode_negel_152_train = negel_152_train.map(whisper_prepare_dataset, num_proc=4)
354
+ # encode_negel_152_dev = negel_152_dev.map(whisper_prepare_dataset, num_proc=4)
355
+ # encode_negel_152_test = negel_152_test.map(whisper_prepare_dataset, num_proc=4)
356
+
357
+ # encoded_train_large = concatenate_datasets([whisper_train_large, encode_negel_152_train])
358
+ # encoded_dev_large = concatenate_datasets([whisper_dev, encode_negel_152_dev])
359
+
360
+ pdb.set_trace()
361
+ torch.cuda.empty_cache()
362
+
363
+ training_args = Seq2SeqTrainingArguments(
364
+ output_dir=fine_tuning_dir, # change to a repo name of your choice
365
+ per_device_train_batch_size=8,
366
+ gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
367
+ learning_rate=1e-5,
368
+ warmup_steps=50,
369
+ max_steps=1000,
370
+ gradient_checkpointing=True,
371
+ fp16=True,
372
+ evaluation_strategy="steps",
373
+ save_strategy="steps",
374
+ per_device_eval_batch_size=8,
375
+ predict_with_generate=True,
376
+ generation_max_length=512,
377
+ save_steps=10,
378
+ eval_steps=10,
379
+ logging_steps=10,
380
+ report_to=["tensorboard"],
381
+ load_best_model_at_end=True,
382
+ metric_for_best_model="wer",
383
+ greater_is_better=False,
384
+ save_total_limit=5,
385
+ push_to_hub=False,
386
+ )
387
+ from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
388
+
389
+ # pdb.set_trace()
390
+ # # from transformers.trainer.callbacks import TensorBoardCallback
391
+ # class EvalLoggingCallback(TrainerCallback):
392
+ # def on_evaluate(self, args, state, control, metrics, **kwargs):
393
+ # print(f"Eval loss: {metrics['eval_loss']:.4f}, Accuracy: {metrics['eval_wer']:.4f}")
394
+
395
+ # pdb.set_trace()
396
+
397
+ trainer = Seq2SeqTrainer(
398
+ args=training_args,
399
+ model=model,
400
+ train_dataset=encoded_Michael_52_train,
401
+ eval_dataset=encoded_Michael_52_dev,
402
+ data_collator=data_collator,
403
+ compute_metrics=compute_metrics,
404
+ tokenizer=processor.feature_extractor,
405
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
406
+
407
+ )
408
+ # callbacks=[EvalLoggingCallback()]
409
+ trainer.train()
410
+ # trainer.evaluate(encoded_P1tony_test, metric_key_prefix="test")
411
+ # trainer.callback_handler.on_test_end(trainer, datasets=encoded_P1tony_test)
412
+
413
+
414
+ # ## Not fine tuned
415
+ # z_result = encoded_test.map(my_map_to_pred)
416
+ # # pdb.set_trace()
417
+ # # 0.4692737430167598
418
+ # z = WER.compute(references=z_result['reference'], predictions=z_result['prediction'])
419
+
420
+ # z_hel_result = encoded_healthy.map(my_map_to_pred)
421
+ # #
422
+ # z_hel = WER.compute(references=z_hel_result['reference'], predictions=z_hel_result['prediction'])
423
+ # # 0.1591610117211598
424
+
425
+ # z_fary_result = encoded_Fary.map(my_map_to_pred)
426
+ # z_far = WER.compute(references=z_fary_result['reference'], predictions=z_fary_result['prediction'])
427
+ # # 0.1791044776119403
428
+
429
+
430
+ # z_john_p326_result = encoded_John_p326.map(my_map_to_pred)
431
+ # z_john_p326 = WER.compute(references=z_john_p326_result['reference'], predictions=z_john_p326_result['prediction'])
432
+ # # 0.4648241206030151
433
+
434
+ # # y_John_video= fine_tuned_trainer.predict(encoded_John_video)
435
+ # # metrics={'test_loss': 2.665189743041992, 'test_wer': 0.7222222222222222, 'test_runtime': 0.1633, 'test_samples_per_second': 48.979, 'test_steps_per_second': 6.122})
436
+ # pdb.set_trace()
437
+
438
+ # p326 training
439
+ # metrics={'test_loss': 0.4804028868675232, 'test_wer': 0.21787709497206703, 'test_runtime': 0.3594, 'test_samples_per_second': 44.517, 'test_steps_per_second': 5.565})
440
+ # hel metrics={'test_loss': 1.6363693475723267, 'test_wer': 0.17951881554595928, 'test_runtime': 3.8451, 'test_samples_per_second': 41.611, 'test_steps_per_second': 5.201})
441
+ # Fary: metrics={'test_loss': 1.4633615016937256, 'test_wer': 0.5572139303482587, 'test_runtime': 0.6627, 'test_samples_per_second': 45.27, 'test_steps_per_second': 6.036})
442
+ # p326 large: metrics={'test_loss': 0.6568527817726135, 'test_wer': 0.2889447236180904, 'test_runtime': 0.7169, 'test_samples_per_second': 51.613, 'test_steps_per_second': 6.975})
local/whisper_fine_tuning_michael_100.py ADDED
@@ -0,0 +1,442 @@
1
+ fine_tuning_dir = "fine_tuned/SSD/model/Michael_100_with_Large_AVA_script_conv_train_conv_dev"
2
+ """
3
+ TODO:
4
+ + [x] Load Configuration
5
+ + [ ] Multi ASR Engine
6
+ + [ ] Batch / Real Time support
7
+ """
8
+ from pathlib import Path
9
+ from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
10
+ from datasets import load_dataset, concatenate_datasets
11
+ from datasets import Dataset, Audio
12
+ import pdb
13
+ import string
14
+ import librosa
15
+ # local import
16
+ import sys
17
+
18
+ sys.path.append("src")
19
+ import torch
20
+ torch.cuda.set_device("cuda:0")
21
+ # token_model = AutoModelForCTC.from_pretrained(
22
+ # "facebook/wav2vec2-base-960h"
23
+ # )
24
+
25
+ # audio_dir= "/Users/kevingeng/Laronix/laronix_automos/data/Patient_sil_trim_16k_normed_5_snr_40/"
26
+ audio_dir ="./data/Patient_sil_trim_16k_normed_5_snr_40"
27
+ healthy_dir="./data/Healthy"
28
+ Fary_PAL_30="./data/Fary_PAL_p326_20230110_30"
29
+ John_p326 = "./data/John_p326/output"
30
+ John_video = "./data/20230103_video"
31
+ p326_300_dir ="./data/John_p326_large"
32
+ negel_152 = "./data/4_negal_152_clean_all"
33
+
34
+ michael3_52 = "data/3_michael_20230619_100"
35
+
36
+ patient_T = "data/Patient_T/Patient_T"
37
+ patient_L = "data/Patient_L/Patient_L"
38
+ P1tony = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/CONVERSATIONAL/PAL"
39
+ P1tony_arthur = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Arthur_the_Rat/PAL"
40
+ P1tony_rainbow = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Rainbow_Passage/Laronix"
41
+
42
+ def dataclean(example):
43
+ # pdb.set_trace()
44
+ if example['audio']['sampling_rate'] != 16000:
45
+ resampled_audio = librosa.resample(y=example['audio']['array'],
46
+ orig_sr= example['audio']['sampling_rate'],
47
+ target_sr=16000)
48
+ # torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
49
+ # resampled_audio = resampler(example['audio']['array'])
50
+
51
+ return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
52
+ "transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
53
+ else:
54
+ return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
55
+
56
+ # patient_L_test_dataset = load_dataset("audiofolder", data_dir=patient_L, split="train")
57
+ # patient_L_test_dataset = patient_L_test_dataset.map(dataclean)
58
+
59
+ # patient_T_test_dataset = load_dataset("audiofolder", data_dir=patient_T, split="train")
60
+ # patient_T_test_dataset = patient_T_test_dataset.map(dataclean)
61
+
62
+ P1tony_dataset = load_dataset("audiofolder", data_dir=P1tony, split="train")
63
+ P1tony_dataset = P1tony_dataset.map(dataclean)
64
+
65
+ P3Micheal_dataset_52 = load_dataset("audiofolder", data_dir=michael3_52, split="train")
66
+ P3Micheal_dataset_52 = P3Micheal_dataset_52.map(dataclean)
67
+
68
+ # negel_152_dataset = load_dataset("audiofolder", data_dir=negel_152, split="train")
69
+ # negel_152_dataset = negel_152_dataset.map(dataclean)
70
+
71
+
72
+ # pdb.set_trace()
73
+ # P1tony_scripted1 = load_dataset("audiofolder", data_dir=P1tony_rainbow, split="train")
74
+ # P1tony_scripted2 = load_dataset("audiofolder", data_dir=P1tony_arthur, split="train")
75
+ # P1tony_scripted1 = P1tony_scripted1.map(dataclean)
76
+ # P1tony_scripted2 = P1tony_scripted2.map(dataclean)
77
+ # P1tony_scripted = concatenate_datasets([P1tony_scripted1, P1tony_scripted2])
78
+
79
+ # audio_dir ="/home/kevingeng/laronix/laronix_automos/data/Healthy"
80
+ # tgt_audio_dir= "/Users/kevingeng/Laronix/Dataset/Pneumatic/automos"
81
+
82
+ # Get Transcription, WER and PPM
83
+ """
84
+ TODO:
85
+ [DONE]: Automatic generating Config
86
+ """
87
+
88
+ import yaml
89
+ import argparse
90
+ import sys
91
+ from pathlib import Path
92
+
93
+ sys.path.append("./src")
94
+ import lightning_module
95
+ # from UV import plot_UV, get_speech_interval
96
+ from transformers import pipeline
97
+ from rich.progress import track
98
+ from rich import print as rprint
99
+ import numpy as np
100
+ import jiwer
101
+ import pdb
102
+ import torch.nn as nn
103
+ import torch
104
+ import torchaudio
105
+ import gradio as gr
106
+ from sys import flags
107
+ from random import sample
108
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
109
+
110
+ import evaluate
111
+
112
+ wer = evaluate.load("wer")
113
+
114
+ # root_path = Path(__file__).parents[1]
115
+
116
+ class ChangeSampleRate(nn.Module):
117
+ def __init__(self, input_rate: int, output_rate: int):
118
+ super().__init__()
119
+ self.output_rate = output_rate
120
+ self.input_rate = input_rate
121
+
122
+ def forward(self, wav: torch.tensor) -> torch.tensor:
123
+ # Only accepts 1-channel waveform input
124
+ wav = wav.view(wav.size(0), -1)
125
+ new_length = wav.size(-1) * self.output_rate // self.input_rate
126
+ indices = torch.arange(new_length) * (
127
+ self.input_rate / self.output_rate
128
+ )
129
+ round_down = wav[:, indices.long()]
130
+ round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
131
+ output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
132
+ 0
133
+ ) + round_up * indices.fmod(1.0).unsqueeze(0)
134
+ return output
135
+
136
+ # resample and clean text data
137
+ def dataclean(example):
138
+ # pdb.set_trace()
139
+ if example['audio']['sampling_rate'] != 16000:
140
+ resampled_audio = librosa.resample(y=example['audio']['array'],
141
+ orig_sr= example['audio']['sampling_rate'],
142
+ target_sr=16000)
143
+ # torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
144
+ # resampled_audio = resampler(example['audio']['array'])
145
+
146
+ return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
147
+ "transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
148
+ else:
149
+ return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
150
+
151
+ # processor = AutoFeatureExtractor.from_pretrained(
152
+ # "facebook/wav2vec2-base-960h"
153
+ # )
154
+ processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
155
+
156
+ def prepare_dataset(batch):
157
+ audio = batch["audio"]
158
+ batch = processor(audio["array"], sampling_rate = audio["sampling_rate"], text=batch['transcription'])
159
+ batch["input_length"] = len(batch["input_values"][0])
160
+ return batch
161
+
162
+ src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
163
+ src_dataset = src_dataset.map(dataclean)
164
+ p326_300_dataset = load_dataset("audiofolder", data_dir=p326_300_dir, split="train")
165
+ p326_300_dataset = p326_300_dataset.map(dataclean)
166
+
167
+ # healthy_test_dataset = load_dataset("audiofolder", data_dir=healthy_dir, split='train')
168
+ # healthy_test_dataset = healthy_test_dataset.map(dataclean)
169
+
170
+ # Fary_PAL_test_dataset = load_dataset("audiofolder", data_dir=Fary_PAL_30, split='train')
171
+ # Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)
172
+
173
+ # John_p326_test_dataset = load_dataset("audiofolder", data_dir=John_p326, split='train')
174
+ # John_p326_test_dataset = John_p326_test_dataset.map(dataclean)
175
+
176
+ # John_video_test_dataset = load_dataset("audiofolder", data_dir=John_video, split='train')
177
+ # John_video_test_dataset = John_video_test_dataset.map(dataclean)
178
+
179
+ # pdb.set_trace()
180
+
181
+ def train_dev_test_split(dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1):
182
+ """
183
+ input: dataset
184
+ dev_rate,
185
+ test_rate
186
+ seed
187
+ -------
188
+ Output:
189
+ dataset_dict{"train", "dev", "test"}
190
+ """
191
+ train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
192
+ test = train_dev_test["test"]
193
+ train_dev = train_dev_test['train']
194
+
195
+ # pdb.set_trace()
196
+ if len(train_dev) <= int(len(dataset)*dev_rate):
197
+ train = Dataset.from_dict({"audio": [], "transcription": []})
198
+ dev = train_dev
199
+ else:
200
+ train_dev = train_dev.train_test_split(test_size=int(len(dataset)*dev_rate), seed=seed)
201
+ train = train_dev['train']
202
+ dev = train_dev['test']
203
+ return train, dev, test
204
+
205
+ # pdb.set_trace()
206
+ # P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1)
207
+ # P1tony_train_ = concatenate_datasets([P1tony_train,P1tony_scripted])
208
+ # pdb.set_trace()
209
+
210
+ Michael_52_train, Michael_52_dev, Michael_52_test = train_dev_test_split(P3Micheal_dataset_52, dev_rate=0.10, test_rate=0.1, seed=1)
211
+
212
+ # train_dev / test
213
+ ds = src_dataset.train_test_split(test_size=0.1, seed=1)
214
+
215
+ # dataset_libri = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
216
+
217
+ train_dev = ds['train']
218
+ # train / dev
219
+ train_dev = train_dev.train_test_split(test_size=int(len(src_dataset)*0.1), seed=1)
220
+ # train/dev/test
221
+ train = train_dev['train']
222
+ test = ds['test']
223
+ dev = train_dev['test']
224
+
225
+ encoded_train = train.map(prepare_dataset, num_proc=4)
226
+ encoded_dev = dev.map(prepare_dataset, num_proc=4)
227
+ encoded_test = test.map(prepare_dataset, num_proc=4)
228
+ p326_encoded_train = p326_300_dataset.map(prepare_dataset, num_proc=4)
229
+
230
+ # combine large p326 in to training set
231
+ encoded_train = concatenate_datasets([encoded_train, p326_encoded_train])
232
+
233
+ # encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
234
+ # encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
235
+ # encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
236
+ # encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
237
+
238
+ # encoded_P1tony_train = P1tony_train.map(prepare_dataset, num_proc=4)
239
+ # encoded_P1tony_dev = P1tony_dev.map(prepare_dataset, num_proc=4)
240
+ # encoded_P1tony_test = P1tony_test.map(prepare_dataset, num_proc=4)
241
+
242
+ # pdb.set_trace()
243
+ import numpy as np
244
+
245
+ WER = evaluate.load("wer")
246
+
247
+ ## Whisper decoding
248
+
249
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperFeatureExtractor, Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperModel
250
+ processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
251
+ # model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to("cuda:0")
252
+ model = WhisperForConditionalGeneration.from_pretrained("./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400", use_auth_token=True).to("cuda:0")
253
+ tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-medium", language="English", task="transcribe")
254
+
255
+ from pathlib import Path
256
+ id = Path(fine_tuning_dir).stem
257
+ pdb.set_trace()
258
+ tokenizer.push_to_hub("KevinGeng/%s"%id)
259
+ # import pdb
260
+ feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-medium")
261
+
262
+ def whisper_prepare_dataset(batch):
263
+ # load and resample audio data from 48 to 16kHz
264
+ audio = batch["audio"]
265
+
266
+ # compute log-Mel input features from input audio array
267
+ batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
268
+
269
+ # encode target text to label ids
270
+ batch["labels"] = tokenizer(batch["transcription"]).input_ids
271
+ return batch
272
+
273
+ torch.cuda.empty_cache()
274
+
275
+ def my_map_to_pred(batch):
276
+ # pdb.set_trace()
277
+ audio = batch["audio"]
278
+ input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
279
+ # batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
280
+ batch["reference"] = processor.tokenizer._normalize(batch['transcription'])
281
+
282
+ with torch.no_grad():
283
+ # predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
284
+ predicted_ids = model.generate(input_features.to("cuda"))[0]
285
+ transcription = model.decode(predicted_ids)
286
+ batch["prediction"] = model.tokenizer._normalize(transcription)
287
+ return batch
288
+
289
+ import torch
290
+
291
+ from dataclasses import dataclass
292
+ from typing import Any, Dict, List, Union
293
+
294
+ @dataclass
295
+ class DataCollatorSpeechSeq2SeqWithPadding:
296
+ processor: Any
297
+
298
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
299
+ # split inputs and labels since they have to be of different lengths and need different padding methods
300
+ # first treat the audio inputs by simply returning torch tensors
301
+ input_features = [{"input_features": feature["input_features"]} for feature in features]
302
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
303
+
304
+ # get the tokenized label sequences
305
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
306
+ # pad the labels to max length
307
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
308
+
309
+ # replace padding with -100 to ignore loss correctly
310
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
311
+
312
+ # if bos token is appended in previous tokenization step,
313
+ # cut bos token here as it's append later anyways
314
+ if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
315
+ labels = labels[:, 1:]
316
+
317
+ batch["labels"] = labels
318
+
319
+ return batch
320
+
321
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
322
+
323
+ def compute_metrics(pred):
324
+ pred_ids = pred.predictions
325
+ label_ids = pred.label_ids
326
+
327
+ # replace -100 with the pad_token_id
328
+ label_ids[label_ids == -100] = tokenizer.pad_token_id
329
+
330
+ # we do not want to group tokens when computing the metrics
331
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
332
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
333
+
334
+ wer = 100 * WER.compute(predictions=pred_str, references=label_str)
335
+
336
+ return {"wer": wer}
337
+
338
+ # whisper_train = train.map(whisper_prepare_dataset, num_proc=4)
339
+ # pdb.set_trace()
340
+ whisper_train_large = encoded_train.map(whisper_prepare_dataset, num_proc=4)
341
+ whisper_dev = dev.map(whisper_prepare_dataset, num_proc=4)
342
+ whisper_test = test.map(whisper_prepare_dataset, num_proc=4)
343
+
344
+ encoded_Michael_52_train = Michael_52_train.map(whisper_prepare_dataset, num_proc=4)
345
+ encoded_Michael_52_dev = Michael_52_dev.map(whisper_prepare_dataset, num_proc=4)
346
+ encoded_Michael_52_test = Michael_52_test.map(whisper_prepare_dataset, num_proc=4)
347
+ # pdb.set_trace()
348
+ # # Add scirtped tony
349
+ # encoded_P1tony_train = P1tony_train_.map(whisper_prepare_dataset, num_proc=4)
350
+ # encoded_P1tony_dev = P1tony_dev.map(whisper_prepare_dataset, num_proc=4)
351
+ # encoded_P1tony_test = P1tony_test.map(whisper_prepare_dataset, num_proc=4)
352
+
353
+ # encode_negel_152_train = negel_152_train.map(whisper_prepare_dataset, num_proc=4)
354
+ # encode_negel_152_dev = negel_152_dev.map(whisper_prepare_dataset, num_proc=4)
355
+ # encode_negel_152_test = negel_152_test.map(whisper_prepare_dataset, num_proc=4)
356
+
357
+ # encoded_train_large = concatenate_datasets([whisper_train_large, encode_negel_152_train])
358
+ # encoded_dev_large = concatenate_datasets([whisper_dev, encode_negel_152_dev])
359
+
360
+ pdb.set_trace()
361
+ torch.cuda.empty_cache()
362
+
363
+ training_args = Seq2SeqTrainingArguments(
364
+ output_dir=fine_tuning_dir, # change to a repo name of your choice
365
+ per_device_train_batch_size=8,
366
+ gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
367
+ learning_rate=1e-5,
368
+ warmup_steps=50,
369
+ max_steps=1000,
370
+ gradient_checkpointing=True,
371
+ fp16=True,
372
+ evaluation_strategy="steps",
373
+ save_strategy="steps",
374
+ per_device_eval_batch_size=8,
375
+ predict_with_generate=True,
376
+ generation_max_length=512,
377
+ save_steps=10,
378
+ eval_steps=10,
379
+ logging_steps=10,
380
+ report_to=["tensorboard"],
381
+ load_best_model_at_end=True,
382
+ metric_for_best_model="wer",
383
+ greater_is_better=False,
384
+ save_total_limit=5,
385
+ push_to_hub=False,
386
+ )
387
+ from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
388
+
389
+ # pdb.set_trace()
390
+ # # from transformers.trainer.callbacks import TensorBoardCallback
391
+ # class EvalLoggingCallback(TrainerCallback):
392
+ # def on_evaluate(self, args, state, control, metrics, **kwargs):
393
+ # print(f"Eval loss: {metrics['eval_loss']:.4f}, Accuracy: {metrics['eval_wer']:.4f}")
394
+
395
+ # pdb.set_trace()
396
+
397
+ trainer = Seq2SeqTrainer(
398
+ args=training_args,
399
+ model=model,
400
+ train_dataset=encoded_Michael_52_train,
401
+ eval_dataset=encoded_Michael_52_dev,
402
+ data_collator=data_collator,
403
+ compute_metrics=compute_metrics,
404
+ tokenizer=processor.feature_extractor,
405
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
406
+
407
+ )
408
+ # callbacks=[EvalLoggingCallback()]
409
+ trainer.train()
410
+ # trainer.evaluate(encoded_P1tony_test, metric_key_prefix="test")
411
+ # trainer.callback_handler.on_test_end(trainer, datasets=encoded_P1tony_test)
412
+
413
+
414
+ # ## Not fine tuned
415
+ # z_result = encoded_test.map(my_map_to_pred)
416
+ # # pdb.set_trace()
417
+ # # 0.4692737430167598
418
+ # z = WER.compute(references=z_result['reference'], predictions=z_result['prediction'])
419
+
420
+ # z_hel_result = encoded_healthy.map(my_map_to_pred)
421
+ # #
422
+ # z_hel = WER.compute(references=z_hel_result['reference'], predictions=z_hel_result['prediction'])
423
+ # # 0.1591610117211598
424
+
425
+ # z_fary_result = encoded_Fary.map(my_map_to_pred)
426
+ # z_far = WER.compute(references=z_fary_result['reference'], predictions=z_fary_result['prediction'])
427
+ # # 0.1791044776119403
428
+
429
+
430
+ # z_john_p326_result = encoded_John_p326.map(my_map_to_pred)
431
+ # z_john_p326 = WER.compute(references=z_john_p326_result['reference'], predictions=z_john_p326_result['prediction'])
432
+ # # 0.4648241206030151
433
+
434
+ # # y_John_video= fine_tuned_trainer.predict(encoded_John_video)
435
+ # # metrics={'test_loss': 2.665189743041992, 'test_wer': 0.7222222222222222, 'test_runtime': 0.1633, 'test_samples_per_second': 48.979, 'test_steps_per_second': 6.122})
436
+ # pdb.set_trace()
437
+
438
+ # p326 training
439
+ # metrics={'test_loss': 0.4804028868675232, 'test_wer': 0.21787709497206703, 'test_runtime': 0.3594, 'test_samples_per_second': 44.517, 'test_steps_per_second': 5.565})
440
+ # hel metrics={'test_loss': 1.6363693475723267, 'test_wer': 0.17951881554595928, 'test_runtime': 3.8451, 'test_samples_per_second': 41.611, 'test_steps_per_second': 5.201})
441
+ # Fary: metrics={'test_loss': 1.4633615016937256, 'test_wer': 0.5572139303482587, 'test_runtime': 0.6627, 'test_samples_per_second': 45.27, 'test_steps_per_second': 6.036})
442
+ # p326 large: metrics={'test_loss': 0.6568527817726135, 'test_wer': 0.2889447236180904, 'test_runtime': 0.7169, 'test_samples_per_second': 51.613, 'test_steps_per_second': 6.975})
local/whisper_fine_tuning_negel.py ADDED
@@ -0,0 +1,431 @@
1
+ fine_tuning_dir = "fine_tuned/SSD/model/Negel_152_AVA_script_conv_train_conv_dev"
2
+ """
3
+ TODO:
4
+ + [x] Load Configuration
5
+ + [ ] Multi ASR Engine
6
+ + [ ] Batch / Real Time support
7
+ """
8
+ from pathlib import Path
9
+ from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
10
+ from datasets import load_dataset, concatenate_datasets
11
+ from datasets import Dataset, Audio
12
+ import pdb
13
+ import string
14
+ import librosa
15
+ # local import
16
+ import sys
17
+
18
+ sys.path.append("src")
19
+ import torch
20
+ torch.cuda.set_device("cuda:0")
21
+ # token_model = AutoModelForCTC.from_pretrained(
22
+ # "facebook/wav2vec2-base-960h"
23
+ # )
24
+
25
+ # audio_dir= "/Users/kevingeng/Laronix/laronix_automos/data/Patient_sil_trim_16k_normed_5_snr_40/"
26
+ audio_dir ="./data/Patient_sil_trim_16k_normed_5_snr_40"
27
+ healthy_dir="./data/Healthy"
28
+ Fary_PAL_30="./data/Fary_PAL_p326_20230110_30"
29
+ John_p326 = "./data/John_p326/output"
30
+ John_video = "./data/20230103_video"
31
+ p326_300_dir ="./data/John_p326_large"
32
+
33
+ negel_152 = "./data/4_negal_152_clean_all"
34
+
35
+ patient_T = "data/Patient_T/Patient_T"
36
+ patient_L = "data/Patient_L/Patient_L"
37
+ P1tony = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/CONVERSATIONAL/PAL"
38
+ P1tony_arthur = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Arthur_the_Rat/PAL"
39
+ P1tony_rainbow = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/SCRIPTED/Rainbow_Passage/Laronix"
40
+
41
+ def dataclean(example):
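+ # Resample any non-16 kHz audio to 16 kHz and normalise the transcription to uppercase with punctuation stripped.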
42
+ # pdb.set_trace()
43
+ if example['audio']['sampling_rate'] != 16000:
44
+ resampled_audio = librosa.resample(y=example['audio']['array'],
45
+ orig_sr= example['audio']['sampling_rate'],
46
+ target_sr=16000)
47
+ # torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
48
+ # resampled_audio = resampler(example['audio']['array'])
49
+
50
+ return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
51
+ "transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
52
+ else:
53
+ return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
54
+
55
+ # patient_L_test_dataset = load_dataset("audiofolder", data_dir=patient_L, split="train")
56
+ # patient_L_test_dataset = patient_L_test_dataset.map(dataclean)
57
+
58
+ # patient_T_test_dataset = load_dataset("audiofolder", data_dir=patient_T, split="train")
59
+ # patient_T_test_dataset = patient_T_test_dataset.map(dataclean)
60
+
61
+ P1tony_dataset = load_dataset("audiofolder", data_dir=P1tony, split="train")
62
+ P1tony_dataset = P1tony_dataset.map(dataclean)
63
+
64
+ negel_152_dataset = load_dataset("audiofolder", data_dir=negel_152, split="train")
65
+ negel_152_dataset = negel_152_dataset.map(dataclean)
66
+ # pdb.set_trace()
67
+ # P1tony_scripted1 = load_dataset("audiofolder", data_dir=P1tony_rainbow, split="train")
68
+ # P1tony_scripted2 = load_dataset("audiofolder", data_dir=P1tony_arthur, split="train")
69
+ # P1tony_scripted1 = P1tony_scripted1.map(dataclean)
70
+ # P1tony_scripted2 = P1tony_scripted2.map(dataclean)
71
+ # P1tony_scripted = concatenate_datasets([P1tony_scripted1, P1tony_scripted2])
72
+
73
+ # audio_dir ="/home/kevingeng/laronix/laronix_automos/data/Healthy"
74
+ # tgt_audio_dir= "/Users/kevingeng/Laronix/Dataset/Pneumatic/automos"
75
+
76
+ # Get Transcription, WER and PPM
77
+ """
78
+ TODO:
79
+ [DONE]: Automatically generate config
80
+ """
81
+
82
+ import yaml
83
+ import argparse
84
+ import sys
85
+ from pathlib import Path
86
+
87
+ sys.path.append("./src")
88
+ import lightning_module
89
+ from UV import plot_UV, get_speech_interval
90
+ from transformers import pipeline
91
+ from rich.progress import track
92
+ from rich import print as rprint
93
+ import numpy as np
94
+ import jiwer
95
+ import pdb
96
+ import torch.nn as nn
97
+ import torch
98
+ import torchaudio
99
+ import gradio as gr
100
+ from sys import flags
101
+ from random import sample
102
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
103
+
104
+ import evaluate
105
+
106
+ wer = evaluate.load("wer")
107
+
108
+ # root_path = Path(__file__).parents[1]
109
+
110
+ class ChangeSampleRate(nn.Module):
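+ # Linear-interpolation resampler; this script resamples with librosa inside dataclean, so this class appears to be unused here.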
111
+ def __init__(self, input_rate: int, output_rate: int):
112
+ super().__init__()
113
+ self.output_rate = output_rate
114
+ self.input_rate = input_rate
115
+
116
+ def forward(self, wav: torch.tensor) -> torch.tensor:
117
+ # Only accepts 1-channel waveform input
118
+ wav = wav.view(wav.size(0), -1)
119
+ new_length = wav.size(-1) * self.output_rate // self.input_rate
120
+ indices = torch.arange(new_length) * (
121
+ self.input_rate / self.output_rate
122
+ )
123
+ round_down = wav[:, indices.long()]
124
+ round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
125
+ output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
126
+ 0
127
+ ) + round_up * indices.fmod(1.0).unsqueeze(0)
128
+ return output
129
+
130
+ # resample and clean text data
131
+ def dataclean(example):
132
+ # pdb.set_trace()
133
+ if example['audio']['sampling_rate'] != 16000:
134
+ resampled_audio = librosa.resample(y=example['audio']['array'],
135
+ orig_sr= example['audio']['sampling_rate'],
136
+ target_sr=16000)
137
+ # torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
138
+ # resampled_audio = resampler(example['audio']['array'])
139
+
140
+ return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
141
+ "transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
142
+ else:
143
+ return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
144
+
145
+ # processor = AutoFeatureExtractor.from_pretrained(
146
+ # "facebook/wav2vec2-base-960h"
147
+ # )
148
+ processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
149
+
150
+ def prepare_dataset(batch):
151
+ audio = batch["audio"]
152
+ batch = processor(audio["array"], sampling_rate = audio["sampling_rate"], text=batch['transcription'])
153
+ batch["input_length"] = len(batch["input_values"][0])
154
+ return batch
155
+
156
+ # src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
157
+ # src_dataset = src_dataset.map(dataclean)
158
+ # p326_300_dataset = load_dataset("audiofolder", data_dir=p326_300_dir, split="train")
159
+ # p326_300_dataset = p326_300_dataset.map(dataclean)
160
+
161
+
162
+ # healthy_test_dataset = load_dataset("audiofolder", data_dir=healthy_dir, split='train')
163
+ # healthy_test_dataset = healthy_test_dataset.map(dataclean)
164
+
165
+ # Fary_PAL_test_dataset = load_dataset("audiofolder", data_dir=Fary_PAL_30, split='train')
166
+ # Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)
167
+
168
+ # John_p326_test_dataset = load_dataset("audiofolder", data_dir=John_p326, split='train')
169
+ # John_p326_test_dataset = John_p326_test_dataset.map(dataclean)
170
+
171
+ # John_video_test_dataset = load_dataset("audiofolder", data_dir=John_video, split='train')
172
+ # John_video_test_dataset = John_video_test_dataset.map(dataclean)
173
+
174
+
175
+
176
+ # pdb.set_trace()
177
+
178
+ def train_dev_test_split(dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1):
179
+ """
180
+ input: dataset
181
+ dev_rate,
182
+ test_rate
183
+ seed
184
+ -------
185
+ Output:
186
+ dataset_dict{"train", "dev", "test"}
187
+ """
188
+ train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
189
+ test = train_dev_test["test"]
190
+ train_dev = train_dev_test['train']
191
+
192
+ # pdb.set_trace()
193
+ if len(train_dev) <= int(len(dataset)*dev_rate):
194
+ train = Dataset.from_dict({"audio": [], "transcription": []})
195
+ dev = train_dev
196
+ else:
197
+ train_dev = train_dev.train_test_split(test_size=int(len(dataset)*dev_rate), seed=seed)
198
+ train = train_dev['train']
199
+ dev = train_dev['test']
200
+ return train, dev, test
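+ # Example usage (mirrors the call below):
+ #   train, dev, test = train_dev_test_split(negel_152_dataset, dev_rate=0.1, test_rate=0.1, seed=1)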
201
+
202
+ # pdb.set_trace()
203
+ # P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1)
204
+ # P1tony_train_ = concatenate_datasets([P1tony_train,P1tony_scripted])
205
+ # pdb.set_trace()
206
+
207
+ negel_152_train, negel_152_dev, negel_152_test = train_dev_test_split(negel_152_dataset, dev_rate=0.1, test_rate=0.1, seed=1)
208
+
209
+ # train_dev / test
210
+ # ds = src_dataset.train_test_split(test_size=0.1, seed=1)
211
+
212
+ # dataset_libri = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
213
+
214
+ # train_dev = ds['train']
215
+ # # train / dev
216
+ # train_dev = train_dev.train_test_split(test_size=int(len(src_dataset)*0.1), seed=1)
217
+ # # train/dev/test
218
+ # train = train_dev['train']
219
+ # test = ds['test']
220
+ # dev = train_dev['test']
221
+
222
+ # encoded_train = train.map(prepare_dataset, num_proc=4)
223
+ # encoded_dev = dev.map(prepare_dataset, num_proc=4)
224
+ # encoded_test = test.map(prepare_dataset, num_proc=4)
225
+ # p326_encoded_train = p326_300_dataset.map(prepare_dataset, num_proc=4)
226
+
227
+ # # combine large p326 in to training set
228
+ # # encoded_train = concatenate_datasets([encoded_train, p326_encoded_train])
229
+
230
+ # encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
231
+ # encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
232
+ # encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
233
+ # encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
234
+
235
+ # encoded_P1tony_train = P1tony_train.map(prepare_dataset, num_proc=4)
236
+ # encoded_P1tony_dev = P1tony_dev.map(prepare_dataset, num_proc=4)
237
+ # encoded_P1tony_test = P1tony_test.map(prepare_dataset, num_proc=4)
238
+
239
+ # pdb.set_trace()
240
+ import numpy as np
241
+
242
+ WER = evaluate.load("wer")
243
+
244
+ ## Whisper decoding
245
+
246
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperFeatureExtractor, Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperModel
247
+ processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
248
+ # model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to("cuda:0")
249
+ model = WhisperForConditionalGeneration.from_pretrained("./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400", use_auth_token=True).to("cuda:0")
250
+ tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-medium", language="English", task="transcribe")
251
+
252
+ from pathlib import Path
253
+ id = Path(fine_tuning_dir).stem
254
+ # pdb.set_trace()
255
+ tokenizer.push_to_hub("KevinGeng/%s"%id)
256
+ # import pdb
257
+ feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-medium")
258
+
259
+ def whisper_prepare_dataset(batch):
260
+ # load audio data (already resampled to 16 kHz in dataclean)
261
+ audio = batch["audio"]
262
+
263
+ # compute log-Mel input features from input audio array
264
+ batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
265
+
266
+ # encode target text to label ids
267
+ batch["labels"] = tokenizer(batch["transcription"]).input_ids
268
+ return batch
269
+
270
+ torch.cuda.empty_cache()
271
+
272
+ def my_map_to_pred(batch):
273
+ # pdb.set_trace()
274
+ audio = batch["audio"]
275
+ input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
276
+ # batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
277
+ batch["reference"] = processor.tokenizer._normalize(batch['transcription'])
278
+
279
+ with torch.no_grad():
280
+ # predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
281
+ predicted_ids = model.generate(input_features.to("cuda"))[0]
282
+ transcription = tokenizer.decode(predicted_ids)
283
+ batch["prediction"] = model.tokenizer._normalize(transcription)
284
+ return batch
285
+
286
+ import torch
287
+
288
+ from dataclasses import dataclass
289
+ from typing import Any, Dict, List, Union
290
+
291
+ @dataclass
292
+ class DataCollatorSpeechSeq2SeqWithPadding:
293
+ processor: Any
294
+
295
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
296
+ # split inputs and labels since they have to be of different lengths and need different padding methods
297
+ # first treat the audio inputs by simply returning torch tensors
298
+ input_features = [{"input_features": feature["input_features"]} for feature in features]
299
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
300
+
301
+ # get the tokenized label sequences
302
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
303
+ # pad the labels to max length
304
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
305
+
306
+ # replace padding with -100 to ignore loss correctly
307
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
308
+
309
+ # if bos token is appended in previous tokenization step,
310
+ # cut bos token here as it's appended later anyway
311
+ if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
312
+ labels = labels[:, 1:]
313
+
314
+ batch["labels"] = labels
315
+
316
+ return batch
317
+
318
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
319
+
320
+ def compute_metrics(pred):
321
+ pred_ids = pred.predictions
322
+ label_ids = pred.label_ids
323
+
324
+ # replace -100 with the pad_token_id
325
+ label_ids[label_ids == -100] = tokenizer.pad_token_id
326
+
327
+ # we do not want to group tokens when computing the metrics
328
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
329
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
330
+
331
+ wer = 100 * WER.compute(predictions=pred_str, references=label_str)
332
+
333
+ return {"wer": wer}
334
+
335
+ # whisper_train = train.map(whisper_prepare_dataset, num_proc=4)
336
+ # pdb.set_trace()
337
+ # whisper_train_large = encoded_train.map(whisper_prepare_dataset, num_proc=4)
338
+ # whisper_dev = dev.map(whisper_prepare_dataset, num_proc=4)
339
+ # whisper_test = test.map(whisper_prepare_dataset, num_proc=4)
340
+ # pdb.set_trace()
341
+ # # Add scripted tony
342
+ # encoded_P1tony_train = P1tony_train_.map(whisper_prepare_dataset, num_proc=4)
343
+ # encoded_P1tony_dev = P1tony_dev.map(whisper_prepare_dataset, num_proc=4)
344
+ # encoded_P1tony_test = P1tony_test.map(whisper_prepare_dataset, num_proc=4)
345
+
346
+ encode_negel_152_train = negel_152_train.map(whisper_prepare_dataset, num_proc=4)
347
+ encode_negel_152_dev = negel_152_dev.map(whisper_prepare_dataset, num_proc=4)
348
+ encode_negel_152_test = negel_152_test.map(whisper_prepare_dataset, num_proc=4)
349
+ pdb.set_trace()
350
+ torch.cuda.empty_cache()
351
+
352
+ training_args = Seq2SeqTrainingArguments(
353
+ output_dir=fine_tuning_dir, # change to a repo name of your choice
354
+ per_device_train_batch_size=8,
355
+ gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
356
+ learning_rate=1e-5,
357
+ warmup_steps=50,
358
+ max_steps=1000,
359
+ gradient_checkpointing=True,
360
+ fp16=True,
361
+ evaluation_strategy="steps",
362
+ save_strategy="steps",
363
+ per_device_eval_batch_size=8,
364
+ predict_with_generate=True,
365
+ generation_max_length=512,
366
+ save_steps=10,
367
+ eval_steps=10,
368
+ logging_steps=10,
369
+ report_to=["tensorboard"],
370
+ load_best_model_at_end=True,
371
+ metric_for_best_model="wer",
372
+ greater_is_better=False,
373
+ save_total_limit=5,
374
+ push_to_hub=True,
375
+ )
376
+ from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
377
+
378
+ # pdb.set_trace()
379
+ # # from transformers.trainer.callbacks import TensorBoardCallback
380
+ # class EvalLoggingCallback(TrainerCallback):
381
+ # def on_evaluate(self, args, state, control, metrics, **kwargs):
382
+ # print(f"Eval loss: {metrics['eval_loss']:.4f}, Accuracy: {metrics['eval_wer']:.4f}")
383
+
384
+ # pdb.set_trace()
385
+
386
+ trainer = Seq2SeqTrainer(
387
+ args=training_args,
388
+ model=model,
389
+ train_dataset=encode_negel_152_train,
390
+ eval_dataset=encode_negel_152_dev,
391
+ data_collator=data_collator,
392
+ compute_metrics=compute_metrics,
393
+ tokenizer=processor.feature_extractor,
394
+ callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
395
+
396
+ )
397
+ # callbacks=[EvalLoggingCallback()]
398
+ trainer.train()
399
+ # trainer.evaluate(encoded_P1tony_test, metric_key_prefix="test")
400
+ # trainer.callback_handler.on_test_end(trainer, datasets=encoded_P1tony_test)
401
+
402
+
403
+ # ## Not fine tuned
404
+ # z_result = encoded_test.map(my_map_to_pred)
405
+ # # pdb.set_trace()
406
+ # # 0.4692737430167598
407
+ # z = WER.compute(references=z_result['reference'], predictions=z_result['prediction'])
408
+
409
+ # z_hel_result = encoded_healthy.map(my_map_to_pred)
410
+ # #
411
+ # z_hel = WER.compute(references=z_hel_result['reference'], predictions=z_hel_result['prediction'])
412
+ # # 0.1591610117211598
413
+
414
+ # z_fary_result = encoded_Fary.map(my_map_to_pred)
415
+ # z_far = WER.compute(references=z_fary_result['reference'], predictions=z_fary_result['prediction'])
416
+ # # 0.1791044776119403
417
+
418
+
419
+ # z_john_p326_result = encoded_John_p326.map(my_map_to_pred)
420
+ # z_john_p326 = WER.compute(references=z_john_p326_result['reference'], predictions=z_john_p326_result['prediction'])
421
+ # # 0.4648241206030151
422
+
423
+ # # y_John_video= fine_tuned_trainer.predict(encoded_John_video)
424
+ # # metrics={'test_loss': 2.665189743041992, 'test_wer': 0.7222222222222222, 'test_runtime': 0.1633, 'test_samples_per_second': 48.979, 'test_steps_per_second': 6.122})
425
+ # pdb.set_trace()
426
+
427
+ # p326 training
428
+ # metrics={'test_loss': 0.4804028868675232, 'test_wer': 0.21787709497206703, 'test_runtime': 0.3594, 'test_samples_per_second': 44.517, 'test_steps_per_second': 5.565})
429
+ # hel metrics={'test_loss': 1.6363693475723267, 'test_wer': 0.17951881554595928, 'test_runtime': 3.8451, 'test_samples_per_second': 41.611, 'test_steps_per_second': 5.201})
430
+ # Fary: metrics={'test_loss': 1.4633615016937256, 'test_wer': 0.5572139303482587, 'test_runtime': 0.6627, 'test_samples_per_second': 45.27, 'test_steps_per_second': 6.036})
431
+ # p326 large: metrics={'test_loss': 0.6568527817726135, 'test_wer': 0.2889447236180904, 'test_runtime': 0.7169, 'test_samples_per_second': 51.613, 'test_steps_per_second': 6.975})
local/whisper_fine_tuning_negel_decode.py ADDED
@@ -0,0 +1,473 @@
1
+ fine_tuning_dir = "fine_tuned/SSD/model/Michael_100_with_Large_AVA_script_conv_train_conv_dev/checkpoint-100"
2
+ # fine_tuning_dir = "fine_tuned/SSD/model/Michael_52_with_Large_AVA_script_conv_train_conv_dev/checkpoint-60"
3
+ # fine_tuning_dir = "fine_tuned/SSD/model/Negel_152_AVA_script_conv_train_conv_dev/checkpoint-100"
4
+ # fine_tuning_dir = "fine_tuned/SSD/model/Tony1_AVA_script_conv_train_conv_dev/checkpoint-160"
5
+ # fine_tuning_dir = "fine_tuned/SSD/model/Negel_with_Large_AVA_script_conv_train_conv_dev/checkpoint-210"
6
+
7
+ """
8
+ TODO:
9
+ + [x] Whisper Fine-Tuned Model Evaluation
10
+ + [ ]
11
+ + [ ] Batch / Real Time support
12
+ """
13
+ from typing import Any, Dict, List, Union
14
+ from dataclasses import dataclass
15
+ from transformers import Seq2SeqTrainer
16
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer, WhisperFeatureExtractor, Seq2SeqTrainingArguments, Seq2SeqTrainer, WhisperModel
17
+ import evaluate
18
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
19
+ from random import sample
20
+ from sys import flags
21
+ import gradio as gr
22
+ import torchaudio
23
+ import torch.nn as nn
24
+ import jiwer
25
+ import numpy as np
26
+ from rich import print as rprint
27
+ from rich.progress import track
28
+ from transformers import pipeline
29
+ import argparse
30
+ import yaml
31
+ import torch
32
+ from pathlib import Path
33
+ from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC, AutoProcessor
34
+ from datasets import load_dataset, concatenate_datasets
35
+ from datasets import Dataset, Audio
36
+ import pdb
37
+ import string
38
+ import librosa
39
+ # local import
40
+ import sys
41
+
42
+ sys.path.append("src")
43
+ import lightning_module
44
+
45
+ torch.cuda.set_device("cuda:0")
46
+
47
+ audio_dir = "./data/Patient_sil_trim_16k_normed_5_snr_40"
48
+ healthy_dir = "./data/Healthy"
49
+ Fary_PAL_30 = "./data/Fary_PAL_p326_20230110_30"
50
+ John_p326 = "./data/John_p326/output"
51
+ John_video = "./data/20230103_video"
52
+ negel_79 = "./data/4_negal_152_clean_all"
53
+ negel_152 = "./data/4_negal_152_clean_all"
54
+ P1tony = "data/Participant1_Tony_Recording/CLEAN_SENTENCES/CONVERSATIONAL/PAL"
55
+
56
+ michael3_52 = "data/3_michael_20230619_100"
57
+
58
+ patient_T = "data/Patient_T/Patient_T"
59
+ patient_L = "data/Patient_L/Patient_L"
60
+ # Get Transcription, WER and PPM
61
+ """
62
+ TODO:
63
+ [DONE]: Automatically generate config
64
+ """
65
+
66
+
67
+ sys.path.append("./src")
68
+
69
+ wer = evaluate.load("wer")
70
+
71
+ # root_path = Path(__file__).parents[1]
72
+
73
+ class ChangeSampleRate(nn.Module):
74
+ def __init__(self, input_rate: int, output_rate: int):
75
+ super().__init__()
76
+ self.output_rate = output_rate
77
+ self.input_rate = input_rate
78
+
79
+ def forward(self, wav: torch.tensor) -> torch.tensor:
80
+ # Only accepts 1-channel waveform input
81
+ wav = wav.view(wav.size(0), -1)
82
+ new_length = wav.size(-1) * self.output_rate // self.input_rate
83
+ indices = torch.arange(new_length) * (
84
+ self.input_rate / self.output_rate
85
+ )
86
+ round_down = wav[:, indices.long()]
87
+ round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
88
+ output = round_down * (1.0 - indices.fmod(1.0)).unsqueeze(
89
+ 0
90
+ ) + round_up * indices.fmod(1.0).unsqueeze(0)
91
+ return output
92
+
93
+ # resample and clean text data
94
+
95
+
96
+ def dataclean(example):
97
+ # pdb.set_trace()
98
+ if example['audio']['sampling_rate'] != 16000:
99
+ resampled_audio = librosa.resample(y=example['audio']['array'],
100
+ orig_sr=example['audio']['sampling_rate'],
101
+ target_sr=16000)
102
+ # torchaudio.transforms.Resample(example['audio']['sampling_rate'], 16000)
103
+ # resampled_audio = resampler(example['audio']['array'])
104
+
105
+ return {"audio": {"path": example['audio']['path'], "array": resampled_audio, "sampling_rate": 16000},
106
+ "transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
107
+ else:
108
+ return {"transcription": example["transcription"].upper().translate(str.maketrans('', '', string.punctuation))}
109
+
110
+ processor = AutoFeatureExtractor.from_pretrained(
111
+ "facebook/wav2vec2-base-960h"
112
+ )
113
+
114
+ def prepare_dataset(batch):
115
+ audio = batch["audio"]
116
+ batch = processor(
117
+ audio["array"], sampling_rate=audio["sampling_rate"], text=batch['transcription'])
118
+ batch["input_length"] = len(batch["input_values"][0])
119
+ return batch
120
+
121
+
122
+ negel_79_dataset = load_dataset("audiofolder", data_dir=negel_79, split="train")
123
+ negel_79_dataset = negel_79_dataset.map(dataclean)
124
+
125
+ def train_dev_test_split(dataset: Dataset, dev_rate=0.1, test_rate=0.1, seed=1):
126
+ """
127
+ input: dataset
128
+ dev_rate,
129
+ test_rate
130
+ seed
131
+ -------
132
+ Output:
133
+ dataset_dict{"train", "dev", "test"}
134
+ """
135
+ train_dev_test = dataset.train_test_split(test_size=test_rate, seed=seed)
136
+ test = train_dev_test["test"]
137
+ train_dev = train_dev_test['train']
138
+
139
+ # pdb.set_trace()
140
+ if len(train_dev) <= int(len(dataset)*dev_rate):
141
+ train = Dataset.from_dict({"audio": [], "transcription": []})
142
+ dev = train_dev
143
+ else:
144
+ train_dev = train_dev.train_test_split(test_size=int(len(dataset)*dev_rate), seed=seed)
145
+ train = train_dev['train']
146
+ dev = train_dev['test']
147
+ return train, dev, test
148
+ P1tony_dataset = load_dataset("audiofolder", data_dir=P1tony, split="train")
149
+ P1tony_dataset = P1tony_dataset.map(dataclean)
150
+ # pdb.set_trace()
151
+ P1tony_train, P1tony_dev, P1tony_test = train_dev_test_split(P1tony_dataset, dev_rate=0.5, test_rate=0.5, seed=1)
152
+
153
+ P3Micheal_dataset_52 = load_dataset("audiofolder", data_dir=michael3_52, split="train")
154
+ P3Micheal_dataset_52 = P3Micheal_dataset_52.map(dataclean)
155
+
156
+ Michael_52_train, Michael_52_dev, Michael_52_test = train_dev_test_split(P3Micheal_dataset_52, dev_rate=0.1, test_rate=0.1, seed=1)
157
+
158
+ # P1tony_train_ = concatenate_datasets([P1tony_train,P1tony_scripted])
159
+ # pdb.set_trace()
160
+
161
+ # Negel_79_train, Negel_79_dev, Negel_79_test = train_dev_test_split(negel_79_dataset, dev_rate=0.1, test_rate=0.1, seed=1)
162
+
163
+ src_dataset = load_dataset("audiofolder", data_dir=audio_dir, split="train")
164
+ src_dataset = src_dataset.map(dataclean)
165
+
166
+ negel_152_dataset = load_dataset("audiofolder", data_dir=negel_152, split="train")
167
+ negel_152_dataset = negel_152_dataset.map(dataclean)
168
+
169
+ healthy_test_dataset = load_dataset(
170
+ "audiofolder", data_dir=healthy_dir, split='train')
171
+ healthy_test_dataset = healthy_test_dataset.map(dataclean)
172
+
173
+ Fary_PAL_test_dataset = load_dataset(
174
+ "audiofolder", data_dir=Fary_PAL_30, split='train')
175
+ Fary_PAL_test_dataset = Fary_PAL_test_dataset.map(dataclean)
176
+
177
+ John_p326_test_dataset = load_dataset(
178
+ "audiofolder", data_dir=John_p326, split='train')
179
+ John_p326_test_dataset = John_p326_test_dataset.map(dataclean)
180
+
181
+ John_video_test_dataset = load_dataset(
182
+ "audiofolder", data_dir=John_video, split='train')
183
+ John_video_test_dataset = John_video_test_dataset.map(dataclean)
184
+
185
+ patient_T_test_dataset = load_dataset("audiofolder", data_dir=patient_T, split='train')
186
+ patient_T_test_dataset = patient_T_test_dataset.map(dataclean)
187
+
188
+ patient_L_test_dataset = load_dataset("audiofolder", data_dir=patient_L, split='train')
189
+ patient_L_test_dataset = patient_L_test_dataset.map(dataclean)
190
+ pdb.set_trace()
191
+
192
+ # train_dev / test
193
+ ds = src_dataset.train_test_split(test_size=0.1, seed=1)
194
+
195
+ dataset_libri = load_dataset(
196
+ "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
197
+
198
+ train_dev = ds['train']
199
+ # train / dev
200
+ train_dev = train_dev.train_test_split(
201
+ test_size=int(len(src_dataset)*0.1), seed=1)
202
+ # train/dev/test
203
+ train = train_dev['train']
204
+ test = ds['test']
205
+ dev = train_dev['test']
206
+
207
+ # # pdb.set_trace()
208
+ encoded_train = train.map(prepare_dataset, num_proc=4)
209
+ encoded_dev = dev.map(prepare_dataset, num_proc=4)
210
+ encoded_test = test.map(prepare_dataset, num_proc=4)
211
+
212
+ encoded_Tony_test = P1tony_test.map(prepare_dataset, num_proc=4)
213
+ encoded_healthy = healthy_test_dataset.map(prepare_dataset, num_proc=4)
214
+ encoded_Fary = Fary_PAL_test_dataset.map(prepare_dataset, num_proc=4)
215
+ encoded_John_p326 = John_p326_test_dataset.map(prepare_dataset, num_proc=4)
216
+ encoded_John_video = John_video_test_dataset.map(prepare_dataset, num_proc=4)
217
+
218
+ # pdb.set_trace()
219
+
220
+ WER = evaluate.load("wer")
221
+
222
+ # Whisper decoding
223
+
224
+ processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
225
+ model = WhisperForConditionalGeneration.from_pretrained(
226
+ "openai/whisper-medium").to("cuda:0")
227
+ tokenizer = WhisperTokenizer.from_pretrained(
228
+ "openai/whisper-medium", language="English", task="transcribe")
229
+
230
+ # Need to push the tokenizer to the Hugging Face model repo to activate the online API
231
+
232
+ # tokenizer.push_to_hub("KevinGeng/whipser_medium_en_PAL300_step25")
233
+ # import pdb
234
+ # pdb.set_trace()
235
+
236
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(
237
+ "openai/whisper-medium")
238
+
239
+
240
+ def whisper_prepare_dataset(batch):
241
+ # load audio data (already resampled to 16 kHz in dataclean)
242
+ audio = batch["audio"]
243
+
244
+ # compute log-Mel input features from input audio array
245
+ batch["input_features"] = feature_extractor(
246
+ audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
247
+
248
+ # encode target text to label ids
249
+ batch["labels"] = tokenizer(batch["transcription"]).input_ids
250
+ return batch
251
+
252
+
253
+ torch.cuda.empty_cache()
254
+
255
+ training_args = Seq2SeqTrainingArguments(
256
+ # change to a repo name of your choice
257
+ output_dir="./whisper-medium-PAL128-25step",
258
+ per_device_train_batch_size=8,
259
+ gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
260
+ learning_rate=1e-5,
261
+ warmup_steps=100,
262
+ max_steps=1000,
263
+ gradient_checkpointing=True,
264
+ fp16=True,
265
+ evaluation_strategy="steps",
266
+ per_device_eval_batch_size=8,
267
+ predict_with_generate=True,
268
+ generation_max_length=512,
269
+ save_steps=100,
270
+ eval_steps=25,
271
+ logging_steps=100,
272
+ report_to=["tensorboard"],
273
+ load_best_model_at_end=True,
274
+ metric_for_best_model="wer",
275
+ greater_is_better=False,
276
+ push_to_hub=True,
277
+ )
278
+
279
+
280
+ def my_map_to_pred(batch):
281
+ # pdb.set_trace()
282
+ audio = batch["audio"]
283
+ input_features = processor(
284
+ audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
285
+ # batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
286
+ batch["reference"] = processor.tokenizer._normalize(batch['transcription'])
287
+
288
+ with torch.no_grad():
289
+ # predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
290
+ predicted_ids = model.generate(input_features.to("cuda"))[0]
291
+ transcription = tokenizer.decode(predicted_ids)
292
+ batch["prediction"] = model.tokenizer._normalize(transcription)
293
+ return batch
294
+
295
+
296
+ @dataclass
297
+ class DataCollatorSpeechSeq2SeqWithPadding:
298
+ processor: Any
299
+
300
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
301
+ # split inputs and labels since they have to be of different lengths and need different padding methods
302
+ # first treat the audio inputs by simply returning torch tensors
303
+ input_features = [{"input_features": feature["input_features"]}
304
+ for feature in features]
305
+ batch = self.processor.feature_extractor.pad(
306
+ input_features, return_tensors="pt")
307
+
308
+ # get the tokenized label sequences
309
+ label_features = [{"input_ids": feature["labels"]}
310
+ for feature in features]
311
+ # pad the labels to max length
312
+ labels_batch = self.processor.tokenizer.pad(
313
+ label_features, return_tensors="pt")
314
+
315
+ # replace padding with -100 to ignore loss correctly
316
+ labels = labels_batch["input_ids"].masked_fill(
317
+ labels_batch.attention_mask.ne(1), -100)
318
+
319
+ # if bos token is appended in previous tokenization step,
320
+ # cut bos token here as it's appended later anyway
321
+ if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
322
+ labels = labels[:, 1:]
323
+
324
+ batch["labels"] = labels
325
+
326
+ return batch
327
+
328
+
329
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
330
+
331
+
332
+ def compute_metrics(pred):
333
+ pdb.set_trace()
334
+ pred_ids = pred.predictions
335
+ label_ids = pred.label_ids
336
+
337
+ # replace -100 with the pad_token_id
338
+ label_ids[label_ids == -100] = tokenizer.pad_token_id
339
+
340
+ # we do not want to group tokens when computing the metrics
341
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
342
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
343
+
344
+ wer = 100 * WER.compute(predictions=pred_str, references=label_str)
345
+
346
+ return {"wer": wer}
347
+
348
+ # encode_negel_79_train = Negel_79_train.map(whisper_prepare_dataset, num_proc=4)
349
+ # encode_negel_79_dev = Negel_79_dev.map(whisper_prepare_dataset, num_proc=4)
350
+ # encode_negel_79_test = Negel_79_test.map(whisper_prepare_dataset, num_proc=4)
351
+ whisper_test = test.map(whisper_prepare_dataset, num_proc=4)
352
+
353
+ encoded_Michael_52_train = Michael_52_train.map(whisper_prepare_dataset, num_proc=4)
354
+ encoded_Michael_52_dev = Michael_52_dev.map(whisper_prepare_dataset, num_proc=4)
355
+ encoded_Michael_52_test = Michael_52_test.map(whisper_prepare_dataset, num_proc=4)
356
+ # negel_152_train, negel_152_dev, negel_152_test = train_dev_test_split(negel_152_dataset, dev_rate=0.1, test_rate=0.1, seed=1)
357
+ # encoded_negel_152_test = negel_152_test.map(whisper_prepare_dataset, num_proc=4)
358
+ pdb.set_trace()
359
+ torch.cuda.empty_cache()
360
+
361
+ torch.cuda.empty_cache()
362
+
363
+ fine_tuned_model = WhisperForConditionalGeneration.from_pretrained(
364
+ fine_tuning_dir
365
+ ).to("cuda")
366
+ # "fine_tuned/SSD/model/whipser_medium_TEP_patient_T"
367
+ # "./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400"
368
+ #"./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-200"
369
+
370
+
371
+ def fine_tuned_map_to_pred(batch):
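+ # Same flow as my_map_to_pred above, but generates with the fine-tuned checkpoint and decodes/normalises with the Whisper tokenizer.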
372
+ # pdb.set_trace()
373
+ audio = batch["audio"]
374
+ input_features = processor(
375
+ audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
376
+ # batch["reference"] = whisper_processor.tokenizer._normalize(batch['text'])
377
+ batch["reference"] = processor.tokenizer._normalize(batch['transcription'])
378
+
379
+ with torch.no_grad():
380
+ # predicted_ids = whisper_model.generate(input_features.to("cuda"))[0]
381
+ predicted_ids = fine_tuned_model.generate(input_features.to("cuda"))[0]
382
+ transcription = tokenizer.decode(predicted_ids)
383
+ batch["prediction"] = tokenizer._normalize(transcription)
384
+ return batch
385
+
386
+
387
+ # output_dir="./fine_tuned/whipser_medium_en_PAL300_step25_step2_VCTK/checkpoint-400",
388
+ # output_dir="fine_tuned/SSD/model/whipser_medium_TEP_patient_TL_TL",
389
+ testing_args = Seq2SeqTrainingArguments(
390
+ # change to a repo name of your choice
391
+ output_dir=fine_tuning_dir,
392
+ per_device_train_batch_size=8,
393
+ gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
394
+ learning_rate=1e-5,
395
+ warmup_steps=100,
396
+ max_steps=1000,
397
+ gradient_checkpointing=True,
398
+ fp16=True,
399
+ evaluation_strategy="steps",
400
+ per_device_eval_batch_size=8,
401
+ predict_with_generate=True,
402
+ generation_max_length=512,
403
+ save_steps=100,
404
+ eval_steps=25,
405
+ logging_steps=100,
406
+ report_to=["tensorboard"],
407
+ load_best_model_at_end=True,
408
+ metric_for_best_model="wer",
409
+ greater_is_better=False,
410
+ push_to_hub=False,
411
+ )
412
+
413
+ predict_trainer = Seq2SeqTrainer(
414
+ args=testing_args,
415
+ model=fine_tuned_model,
416
+ data_collator=data_collator,
417
+ compute_metrics=compute_metrics,
418
+ tokenizer=processor.feature_extractor,
419
+ )
420
+
421
+ # trainer.train()
422
+ # fine tuned
423
+ # z_result = encoded_test.map(fine_tuned_map_to_pred)
424
+
425
+ pdb.set_trace()
426
+ encoded_Michael_test_result = encoded_Michael_52_test.map(fine_tuned_map_to_pred)
427
+ z_M = WER.compute(references=encoded_Michael_test_result['reference'], predictions=encoded_Michael_test_result['prediction'])
428
+ pdb.set_trace()
429
+ encoded_Tony_test_result = encoded_Tony_test.map(fine_tuned_map_to_pred)
430
+ z = WER.compute(references=encoded_Tony_test_result['reference'], predictions=encoded_Tony_test_result['prediction'])
431
+ pdb.set_trace()
432
+
433
+ z_result= test.map(fine_tuned_map_to_pred)
434
+ # 0.4692737430167598
435
+ z = WER.compute(references=z_result['reference'], predictions=z_result['prediction'])
436
+ # pdb.set_trace()
437
+ z_hel_result = encoded_healthy.map(fine_tuned_map_to_pred)
438
+ z_hel = WER.compute(references=z_hel_result['reference'], predictions=z_hel_result['prediction'])
439
+ # 0.1591610117211598
440
+
441
+ # encoded_negel_152_test
442
+ # encoded_negel_test_result = encoded_negel_152_test.map(fine_tuned_map_to_pred)
443
+ # z_negel = WER.compute(references=encoded_negel_test_result['reference'], predictions=encoded_negel_test_result['prediction'])
444
+
445
+ pdb.set_trace()
446
+ z_fary_result = encoded_Fary.map(fine_tuned_map_to_pred)
447
+ z_far = WER.compute(references=z_fary_result['reference'], predictions=z_fary_result['prediction'])
448
+ # 0.1791044776119403
449
+ # z_patient_LT = encoded_patient_TL_test.map(fine_tuned_map_to_pred)
450
+ # z_patient_LT_result = WER.compute(references=z_patient_LT['reference'], predictions=z_patient_LT['prediction'])
451
+ # z_patient_L = encoded_patient_L_test.map(fine_tuned_map_to_pred)
452
+ # z_patient_L_result = WER.compute(references=z_patient_L['reference'], predictions=z_patient_L['prediction'])
453
+ # z_patient_T = encoded_patient_T_test.map(fine_tuned_map_to_pred)
454
+ # z_patient_T_result = WER.compute(references=z_patient_T['reference'], predictions=z_patient_T['prediction'])
455
+
456
+ z_john_p326_result = encoded_John_p326.map(fine_tuned_map_to_pred)
457
+ z_john_p326 = WER.compute(references=z_john_p326_result['reference'], predictions=z_john_p326_result['prediction'])
458
+
459
+ pdb.set_trace()
460
+
461
+ # # z_john_p326 = WER.compute(references=z_john_p326_result['reference'], predictions=z_john_p326_result['prediction'])
462
+ # # 0.4648241206030151
463
+ pdb.set_trace()
464
+
465
+ # # y_John_video= fine_tuned_trainer.predict(encoded_John_video)
466
+ # # metrics={'test_loss': 2.665189743041992, 'test_wer': 0.7222222222222222, 'test_runtime': 0.1633, 'test_samples_per_second': 48.979, 'test_steps_per_second': 6.122})
467
+ # pdb.set_trace()
468
+
469
+ # p326 training
470
+ # metrics={'test_loss': 0.4804028868675232, 'test_wer': 0.21787709497206703, 'test_runtime': 0.3594, 'test_samples_per_second': 44.517, 'test_steps_per_second': 5.565})
471
+ # hel metrics={'test_loss': 1.6363693475723267, 'test_wer': 0.17951881554595928, 'test_runtime': 3.8451, 'test_samples_per_second': 41.611, 'test_steps_per_second': 5.201})
472
+ # Fary: metrics={'test_loss': 1.4633615016937256, 'test_wer': 0.5572139303482587, 'test_runtime': 0.6627, 'test_samples_per_second': 45.27, 'test_steps_per_second': 6.036})
473
+ # p326 large: metrics={'test_loss': 0.6568527817726135, 'test_wer': 0.2889447236180904, 'test_runtime': 0.7169, 'test_samples_per_second': 51.613, 'test_steps_per_second': 6.975})
packages.txt ADDED
@@ -0,0 +1,2 @@
1
+ festival
2
+ espeak # or espeak-ng on Linux
requirements.txt ADDED
@@ -0,0 +1,103 @@
1
+ absl-py==1.0.0
2
+ aiohttp==3.8.1
3
+ aiosignal==1.2.0
4
+ analytics-python==1.4.0
5
+ antlr4-python3-runtime==4.8
6
+ anyio==3.5.0
7
+ asgiref==3.5.0
8
+ async-timeout==4.0.2
9
+ attrs==21.4.0
10
+ backoff==1.10.0
11
+ bcrypt==3.2.0
12
+ bitarray==2.4.0
13
+ cachetools==5.0.0
14
+ certifi==2021.10.8
15
+ cffi==1.15.0
16
+ charset-normalizer==2.0.12
17
+ click==8.0.4
18
+ colorama==0.4.4
19
+ cryptography==36.0.1
20
+ cycler==0.11.0
21
+ Cython==0.29.28
22
+ fairseq @ git+https://github.com/pytorch/fairseq.git@d03f4e771484a433f025f47744017c2eb6e9c6bc
23
+ fastapi==0.75.0
24
+ ffmpy==0.3.0
25
+ fonttools==4.30.0
26
+ frozenlist==1.3.0
27
+ fsspec==2022.2.0
28
+ future==0.18.2
29
+ google-auth==2.6.0
30
+ google-auth-oauthlib==0.4.6
31
+ gradio==3.40.0
32
+ grpcio==1.44.0
33
+ h11==0.12.0
34
+ hydra-core==1.0.7
35
+ idna==3.3
36
+ importlib-metadata==4.11.3
37
+ Jinja2==3.0.3
38
+ kiwisolver==1.3.2
39
+ linkify-it-py==1.0.3
40
+ Markdown==3.3.6
41
+ markdown-it-py==2.0.1
42
+ MarkupSafe==2.1.0
43
+ matplotlib==3.5.1
44
+ mdit-py-plugins==0.3.0
45
+ mdurl==0.1.0
46
+ monotonic==1.6
47
+ multidict==6.0.2
48
+ numpy==1.22.3
49
+ oauthlib==3.2.0
50
+ omegaconf==2.0.6
51
+ orjson==3.6.7
52
+ packaging==21.3
53
+ pandas==1.4.1
54
+ paramiko==2.10.1
55
+ Pillow==9.0.1
56
+ portalocker==2.4.0
57
+ protobuf==3.19.4
58
+ pyasn1==0.4.8
59
+ pyasn1-modules==0.2.8
60
+ pycparser==2.21
61
+ pycryptodome==3.14.1
62
+ pydantic==1.9.0
63
+ pyDeprecate==0.3.1
64
+ pydub==0.25.1
65
+ PyNaCl==1.5.0
66
+ pyparsing==3.0.7
67
+ python-dateutil==2.8.2
68
+ python-multipart==0.0.5
69
+ pytorch-lightning==1.5.10
70
+ pytz==2021.3
71
+ PyYAML==6.0
72
+ regex==2022.3.2
73
+ requests==2.27.1
74
+ requests-oauthlib==1.3.1
75
+ rsa==4.8
76
+ sacrebleu==2.0.0
77
+ six==1.16.0
78
+ sniffio==1.2.0
79
+ starlette==0.17.1
80
+ tabulate==0.8.9
81
+ tensorboard==2.8.0
82
+ tensorboard-data-server==0.6.1
83
+ tensorboard-plugin-wit==1.8.1
84
+ torch==1.12.1
85
+ torchaudio==0.12.1
86
+ torchmetrics==0.7.2
87
+ tqdm==4.63.0
88
+ typing-extensions==4.1.1
89
+ uc-micro-py==1.0.1
90
+ urllib3==1.26.8
91
+ uvicorn==0.17.6
92
+ Werkzeug==2.0.3
93
+ yarl==1.7.2
94
+ zipp==3.7.0
95
+
96
+ transformers
97
+ deepspeech
98
+ tensorboardX
99
+ jiwer
100
+ phonemizer
101
+ librosa
102
+
103
+ rich
src/description.html ADDED
@@ -0,0 +1,13 @@
1
+ <p>This is the experiment page for Laronix Data Recording.<br>
2
+ <br>
3
+ 1. Select one example from below; a sound file, its reference transcription, and its speaking rate will be loaded as inputs.<br>
4
+ You can check the sound file first and prepare for reading the transcription at a similar tempo.<br>
5
+ 2. Delete the sound file (click the X button on the right); a recording button will appear.<br>
6
+ 3. Click the recording button to start, then click again to stop. Make sure you are not mispronouncing any words or introducing detectable noise.<br>
7
+ 4. Click &quot;Submit&quot; button and wait for the result.<br>
8
+ 5. Please check the message box for feedback; if ERROR appears, delete your previous recording and try again :).<br>
9
+ 6. If the &quot;GOOD JOB!&quot; message appears, click &quot;Flag as Perfect&quot; and start another recording.<br>
10
+ 7. If you try several times (N >= 10) and still cannot clear the mission, you can flag your best recording by clicking &quot;Doubtful Speaking Rate&quot; or &quot;Doubtful Naturalness&quot;. <br>
11
+ Yet this seldom happens, so please try to meet the system's requirement first!<br>
12
+ 8. If you have any other questions, please contact kevin@laronix.com </p>
13
+ <img src="https://static.wixstatic.com/media/e7e144_93e98148d06147828031797eb4525b80~mv2.png/v1/crop/x_0,y_25,w_2606,h_882/fill/w_396,h_142,al_c,q_85,usm_0.66_1.00_0.01,enc_auto/newlogo.png" align="right" height="20%" width="20%">
src/lightning_module.py ADDED
@@ -0,0 +1,43 @@
1
+ import sys
2
+ sys.path.append("src")
3
+ import pytorch_lightning as pl
4
+ import torch
5
+ import torch.nn as nn
6
+ import os
7
+ import numpy as np
8
+ import hydra
9
+ from model import load_ssl_model, PhonemeEncoder, DomainEmbedding, LDConditioner, Projection
10
+
11
+
12
+ class BaselineLightningModule(pl.LightningModule):
13
+ def __init__(self, cfg):
14
+ super().__init__()
15
+ self.cfg = cfg
16
+ self.construct_model()
17
+ self.save_hyperparameters()
18
+
19
+ def construct_model(self):
20
+ self.feature_extractors = nn.ModuleList([
21
+ load_ssl_model(cp_path='src/wav2vec_small.pt'),
22
+ DomainEmbedding(3,128),
23
+ ])
24
+ output_dim = sum([ feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors])
25
+ output_layers = [
26
+ LDConditioner(judge_dim=128,num_judges=3000,input_dim=output_dim)
27
+ ]
28
+ output_dim = output_layers[-1].get_output_dim()
29
+ output_layers.append(
30
+ Projection(hidden_dim=2048,activation=torch.nn.ReLU(),range_clipping=False,input_dim=output_dim)
31
+
32
+ )
33
+
34
+ self.output_layers = nn.ModuleList(output_layers)
35
+
36
+ def forward(self, inputs):
37
+ outputs = {}
38
+ for feature_extractor in self.feature_extractors:
39
+ outputs.update(feature_extractor(inputs))
40
+ x = outputs
41
+ for output_layer in self.output_layers:
42
+ x = output_layer(x,inputs)
43
+ return x
src/model.py ADDED
@@ -0,0 +1,191 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import fairseq
4
+ import os
5
+ import hydra
6
+
7
+ def load_ssl_model(cp_path):
8
+ ssl_model_type = cp_path.split("/")[-1]
9
+ wavlm = "WavLM" in ssl_model_type
10
+ if wavlm:
11
+ checkpoint = torch.load(cp_path)
12
+ cfg = WavLMConfig(checkpoint['cfg'])
13
+ ssl_model = WavLM(cfg)
14
+ ssl_model.load_state_dict(checkpoint['model'])
15
+ if 'Large' in ssl_model_type:
16
+ SSL_OUT_DIM = 1024
17
+ else:
18
+ SSL_OUT_DIM = 768
19
+ else:
20
+ if ssl_model_type == "wav2vec_small.pt":
21
+ SSL_OUT_DIM = 768
22
+ elif ssl_model_type in ["w2v_large_lv_fsh_swbd_cv.pt", "xlsr_53_56k.pt"]:
23
+ SSL_OUT_DIM = 1024
24
+ else:
25
+ print("*** ERROR *** SSL model type " + ssl_model_type + " not supported.")
26
+ exit()
27
+ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
28
+ [cp_path]
29
+ )
30
+ ssl_model = model[0]
31
+ ssl_model.remove_pretraining_modules()
32
+ return SSL_model(ssl_model, SSL_OUT_DIM, wavlm)
33
+
34
+ class SSL_model(nn.Module):
35
+ def __init__(self,ssl_model,ssl_out_dim,wavlm) -> None:
36
+ super(SSL_model,self).__init__()
37
+ self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
38
+ self.WavLM = wavlm
39
+
40
+ def forward(self,batch):
41
+ wav = batch['wav']
42
+ wav = wav.squeeze(1) # [batches, audio_len]
43
+ if self.WavLM:
44
+ x = self.ssl_model.extract_features(wav)[0]
45
+ else:
46
+ res = self.ssl_model(wav, mask=False, features_only=True)
47
+ x = res["x"]
48
+ return {"ssl-feature":x}
49
+ def get_output_dim(self):
50
+ return self.ssl_out_dim
51
+
52
+
53
+ class PhonemeEncoder(nn.Module):
54
+ '''
55
+ PhonemeEncoder consists of an embedding layer, an LSTM layer, and a linear layer.
56
+ Args:
57
+ vocab_size: the size of the vocabulary
58
+ hidden_dim: the size of the hidden state of the LSTM
59
+ emb_dim: the size of the embedding layer
60
+ out_dim: the size of the output of the linear layer
61
+ n_lstm_layers: the number of LSTM layers
62
+ '''
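+ # NOTE: PhonemeEncoder is imported by lightning_module.py but not instantiated in BaselineLightningModule.construct_model, which uses only the SSL and domain features.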
63
+ def __init__(self, vocab_size, hidden_dim, emb_dim, out_dim,n_lstm_layers,with_reference=True) -> None:
64
+ super().__init__()
65
+ self.with_reference = with_reference
66
+ self.embedding = nn.Embedding(vocab_size, emb_dim)
67
+ self.encoder = nn.LSTM(emb_dim, hidden_dim,
68
+ num_layers=n_lstm_layers, dropout=0.1, bidirectional=True)
69
+ self.linear = nn.Sequential(
70
+ nn.Linear(hidden_dim + hidden_dim*self.with_reference, out_dim),
71
+ nn.ReLU()
72
+ )
73
+ self.out_dim = out_dim
74
+
75
+ def forward(self,batch):
76
+ seq = batch['phonemes']
77
+ lens = batch['phoneme_lens']
78
+ reference_seq = batch['reference']
79
+ reference_lens = batch['reference_lens']
80
+ emb = self.embedding(seq)
81
+ emb = torch.nn.utils.rnn.pack_padded_sequence(
82
+ emb, lens, batch_first=True, enforce_sorted=False)
83
+ _, (ht, _) = self.encoder(emb)
84
+ feature = ht[-1] + ht[0]
85
+ if self.with_reference:
86
+ if reference_seq==None or reference_lens ==None:
87
+ raise ValueError("reference_batch and reference_lens should not be None when with_reference is True")
88
+ reference_emb = self.embedding(reference_seq)
89
+ reference_emb = torch.nn.utils.rnn.pack_padded_sequence(
90
+ reference_emb, reference_lens, batch_first=True, enforce_sorted=False)
91
+ _, (ht_ref, _) = self.encoder(reference_emb)
92
+ reference_feature = ht_ref[-1] + ht_ref[0]
93
+ feature = self.linear(torch.cat([feature,reference_feature],1))
94
+ else:
95
+ feature = self.linear(feature)
96
+ return {"phoneme-feature": feature}
97
+ def get_output_dim(self):
98
+ return self.out_dim
99
+
100
+ class DomainEmbedding(nn.Module):
101
+ def __init__(self,n_domains,domain_dim) -> None:
102
+ super().__init__()
103
+ self.embedding = nn.Embedding(n_domains,domain_dim)
104
+ self.output_dim = domain_dim
105
+ def forward(self, batch):
106
+ return {"domain-feature": self.embedding(batch['domains'])}
107
+ def get_output_dim(self):
108
+ return self.output_dim
109
+
110
+
111
+ class LDConditioner(nn.Module):
112
+ '''
113
+ Conditions ssl output by listener embedding
114
+ '''
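+ # The judge (listener) embedding is broadcast over time and concatenated with the frame-level SSL (and optional phoneme/domain) features before a bidirectional LSTM decoder.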
115
+ def __init__(self,input_dim, judge_dim, num_judges=None):
116
+ super().__init__()
117
+ self.input_dim = input_dim
118
+ self.judge_dim = judge_dim
119
+ self.num_judges = num_judges
120
+ assert num_judges !=None
121
+ self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
122
+ # concat [self.output_layer, phoneme features]
123
+
124
+ self.decoder_rnn = nn.LSTM(
125
+ input_size = self.input_dim + self.judge_dim,
126
+ hidden_size = 512,
127
+ num_layers = 1,
128
+ batch_first = True,
129
+ bidirectional = True
130
+ ) # linear?
131
+ self.out_dim = self.decoder_rnn.hidden_size*2
132
+
133
+ def get_output_dim(self):
134
+ return self.out_dim
135
+
136
+
137
+ def forward(self, x, batch):
138
+ judge_ids = batch['judge_id']
139
+ if 'phoneme-feature' in x.keys():
140
+ concatenated_feature = torch.cat((x['ssl-feature'], x['phoneme-feature'].unsqueeze(1).expand(-1,x['ssl-feature'].size(1) ,-1)),dim=2)
141
+ else:
142
+ concatenated_feature = x['ssl-feature']
143
+ if 'domain-feature' in x.keys():
144
+ concatenated_feature = torch.cat(
145
+ (
146
+ concatenated_feature,
147
+ x['domain-feature']
148
+ .unsqueeze(1)
149
+ .expand(-1, concatenated_feature.size(1), -1),
150
+ ),
151
+ dim=2,
152
+ )
153
+ if judge_ids != None:
154
+ concatenated_feature = torch.cat(
155
+ (
156
+ concatenated_feature,
157
+ self.judge_embedding(judge_ids)
158
+ .unsqueeze(1)
159
+ .expand(-1, concatenated_feature.size(1), -1),
160
+ ),
161
+ dim=2,
162
+ )
163
+ decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
164
+ return decoder_output
165
+
166
+ class Projection(nn.Module):
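+ # Maps each conditioned frame to a scalar score; with range_clipping=True the tanh output is rescaled to the 1-5 MOS range (x * 2 + 3).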
167
+ def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
168
+ super(Projection, self).__init__()
169
+ self.range_clipping = range_clipping
170
+ output_dim = 1
171
+ if range_clipping:
172
+ self.proj = nn.Tanh()
173
+
174
+ self.net = nn.Sequential(
175
+ nn.Linear(input_dim, hidden_dim),
176
+ activation,
177
+ nn.Dropout(0.3),
178
+ nn.Linear(hidden_dim, output_dim),
179
+ )
180
+ self.output_dim = output_dim
181
+
182
+ def forward(self, x, batch):
183
+ output = self.net(x)
184
+
185
+ # range clipping
186
+ if self.range_clipping:
187
+ return self.proj(output) * 2.0 + 3
188
+ else:
189
+ return output
190
+ def get_output_dim(self):
191
+ return self.output_dim