antoniomae1234 committed on
Commit 2493d72
1 parent: 945170a

changes in flenema

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .cardboardlint.yml +5 -0
  2. .circleci/config.yml +53 -0
  3. .compute +17 -0
  4. .dockerignore +1 -0
  5. .github/ISSUE_TEMPLATE.md +19 -0
  6. .github/PR_TEMPLATE.md +18 -0
  7. .github/stale.yml +19 -0
  8. .gitignore +132 -0
  9. .pylintrc +586 -0
  10. CODE_OF_CONDUCT.md +19 -0
  11. CODE_OWNERS.rst +75 -0
  12. CONTRIBUTING.md +51 -0
  13. LICENSE.txt +373 -0
  14. MANIFEST.in +11 -0
  15. README.md +281 -3
  16. TTS/.models.json +77 -0
  17. TTS/__init__.py +0 -0
  18. TTS/bin/__init__.py +0 -0
  19. TTS/bin/compute_attention_masks.py +166 -0
  20. TTS/bin/compute_embeddings.py +130 -0
  21. TTS/bin/compute_statistics.py +90 -0
  22. TTS/bin/convert_melgan_tflite.py +32 -0
  23. TTS/bin/convert_melgan_torch_to_tf.py +116 -0
  24. TTS/bin/convert_tacotron2_tflite.py +37 -0
  25. TTS/bin/convert_tacotron2_torch_to_tf.py +213 -0
  26. TTS/bin/distribute.py +69 -0
  27. TTS/bin/synthesize.py +218 -0
  28. TTS/bin/train_encoder.py +274 -0
  29. TTS/bin/train_glow_tts.py +657 -0
  30. TTS/bin/train_speedy_speech.py +618 -0
  31. TTS/bin/train_tacotron.py +731 -0
  32. TTS/bin/train_vocoder_gan.py +664 -0
  33. TTS/bin/train_vocoder_wavegrad.py +511 -0
  34. TTS/bin/train_vocoder_wavernn.py +539 -0
  35. TTS/bin/tune_wavegrad.py +91 -0
  36. TTS/server/README.md +65 -0
  37. TTS/server/__init__.py +0 -0
  38. TTS/server/conf.json +12 -0
  39. TTS/server/server.py +116 -0
  40. TTS/server/static/TTS_circle.png +0 -0
  41. TTS/server/templates/details.html +131 -0
  42. TTS/server/templates/index.html +114 -0
  43. TTS/speaker_encoder/README.md +18 -0
  44. TTS/speaker_encoder/__init__.py +0 -0
  45. TTS/speaker_encoder/config.json +103 -0
  46. TTS/speaker_encoder/dataset.py +169 -0
  47. TTS/speaker_encoder/losses.py +160 -0
  48. TTS/speaker_encoder/model.py +112 -0
  49. TTS/speaker_encoder/requirements.txt +2 -0
  50. TTS/speaker_encoder/umap.png +0 -0
.cardboardlint.yml ADDED
@@ -0,0 +1,5 @@
+ linters:
+ - pylint:
+     # pylintrc: pylintrc
+     filefilter: ['- test_*.py', '+ *.py', '- *.npy']
+     # exclude:
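
The `filefilter` patterns above are applied in order: `- test_*.py` skips tests, `+ *.py` lints every other Python file, and `- *.npy` excludes data blobs. A minimal sketch of running the linter against this config, using the same commands the repository's CONTRIBUTING.md and PR template give below (cardboardlint reads `.cardboardlint.yml` from the working directory):

```bash
# Install the linter stack and lint only what differs from master,
# honoring the filefilter rules in .cardboardlint.yml.
pip install pylint cardboardlint
cardboardlinter --refspec master -n auto
```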
.circleci/config.yml ADDED
@@ -0,0 +1,53 @@
+ version: 2
+
+ workflows:
+   version: 2
+   test:
+     jobs:
+       - test-3.6
+       - test-3.7
+       - test-3.8
+
+ executor: ubuntu-latest
+
+ on:
+   push:
+   pull_request:
+     types: [opened, synchronize, reopened]
+
+ jobs:
+   test-3.6: &test-template
+     docker:
+       - image: circleci/python:3.6
+     resource_class: large
+     working_directory: ~/repo
+     steps:
+       - checkout
+       - run: |
+           sudo apt update
+           sudo apt install espeak git
+       - run: sudo pip install --upgrade pip
+       - run: sudo pip install -e .
+       - run: |
+           sudo pip install --quiet --upgrade cardboardlint pylint
+           cardboardlinter --refspec ${CIRCLE_BRANCH} -n auto
+       - run: nosetests tests --nocapture
+       - run: |
+           sudo ./tests/test_server_package.sh
+           sudo ./tests/test_glow-tts_train.sh
+           sudo ./tests/test_server_package.sh
+           sudo ./tests/test_tacotron_train.sh
+           sudo ./tests/test_vocoder_gan_train.sh
+           sudo ./tests/test_vocoder_wavegrad_train.sh
+           sudo ./tests/test_vocoder_wavernn_train.sh
+           sudo ./tests/test_speedy_speech_train.sh
+
+   test-3.7:
+     <<: *test-template
+     docker:
+       - image: circleci/python:3.7
+
+   test-3.8:
+     <<: *test-template
+     docker:
+       - image: circleci/python:3.8
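
Note how `test-3.6` defines a YAML anchor (`&test-template`) that `test-3.7` and `test-3.8` merge in via `<<: *test-template`, overriding only the Docker image. A quick way to catch indentation or anchor mistakes before pushing is the CircleCI CLI's validator (a sketch; assumes the `circleci` CLI is installed locally):

```bash
# Validate the workflow file offline; exits non-zero on schema errors.
circleci config validate .circleci/config.yml
```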
.compute ADDED
@@ -0,0 +1,17 @@
+ #!/bin/bash
+ yes | apt-get install sox
+ yes | apt-get install ffmpeg
+ yes | apt-get install espeak
+ yes | apt-get install tmux
+ yes | apt-get install zsh
+ sh -c "$(curl -fsSL https://raw.githubusercontent.com/robbyrussell/oh-my-zsh/master/tools/install.sh)"
+ pip3 install https://download.pytorch.org/whl/cu100/torch-1.3.0%2Bcu100-cp36-cp36m-linux_x86_64.whl
+ sudo sh install.sh
+ # pip install pytorch==1.7.0+cu100
+ # python3 setup.py develop
+ # python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
+ # cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
+ # python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/
+ # python3 distribute.py --config_path config.json --data_path /data/rw/home/LibriTTS/train-clean-360
+ # python3 distribute.py --config_path config.json
+ while true; do sleep 1000000; done
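
The script pins a CUDA 10.0 build of PyTorch 1.3.0 and ends with an infinite `sleep` loop so the compute instance stays alive after provisioning. A quick sanity check after it runs (a sketch; assumes a CUDA 10.0-capable driver on the machine):

```bash
# Confirm the pinned wheel imports and can see the GPU.
python3 -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```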
.dockerignore ADDED
@@ -0,0 +1 @@
+ .git/
.github/ISSUE_TEMPLATE.md ADDED
@@ -0,0 +1,19 @@
+ ---
+ name: 'TTS Discourse'
+ about: Please consider using the TTS Discourse page.
+ title: ''
+ labels: ''
+ assignees: ''
+
+ ---
+ <b>Questions</b> will not be answered here!!
+
+ Help is much more valuable if it's shared publicly, so that more people can benefit from it.
+
+ Please consider posting on the [TTS Discourse](https://discourse.mozilla.org/c/tts) page or the Matrix [chat room](https://matrix.to/#/!KTePhNahjgiVumkqca:matrix.org?via=matrix.org) if your issue is not directly related to TTS development (bugs, code updates, etc.).
+
+ You can also check https://github.com/mozilla/TTS/wiki/FAQ for common questions and answers.
+
+ Happy posting!
+
+ https://discourse.mozilla.org/c/tts
.github/PR_TEMPLATE.md ADDED
@@ -0,0 +1,18 @@
+ ---
+ name: 'Contribution Guideline'
+ about: Refer to the Contribution Guideline
+ title: ''
+ labels: ''
+ assignees: ''
+
+ ---
+ ### Contribution Guideline
+
+ Please send your PRs to the `dev` branch unless they are directly related to a specific branch.
+ Before making a Pull Request, check your changes for basic mistakes and style problems by using a linter.
+ We have cardboardlinter set up in this repository, so for example, if you've made some changes and would like to run the linter on just the changed code, you can use the following command:
+
+ ```bash
+ pip install pylint cardboardlint
+ cardboardlinter --refspec master
+ ```
.github/stale.yml ADDED
@@ -0,0 +1,19 @@
+ # Number of days of inactivity before an issue becomes stale
+ daysUntilStale: 60
+ # Number of days of inactivity before a stale issue is closed
+ daysUntilClose: 7
+ # Issues with these labels will never be considered stale
+ exemptLabels:
+   - pinned
+   - security
+ # Label to use when marking an issue as stale
+ staleLabel: wontfix
+ # Comment to post when marking an issue as stale. Set to `false` to disable
+ markComment: >
+   This issue has been automatically marked as stale because it has not had
+   recent activity. It will be closed if no further activity occurs. Thank you
+   for your contributions. You might also check our Discourse page for further help:
+   https://discourse.mozilla.org/c/tts
+ # Comment to post when closing a stale issue. Set to `false` to disable
+ closeComment: false
+
.gitignore ADDED
@@ -0,0 +1,132 @@
+ WadaSNR/
+ .idea/
+ *.pyc
+ .DS_Store
+ ./__init__.py
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ .static_storage/
+ .media/
+ local_settings.py
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+
+ # vim
+ *.swp
+ *.swm
+ *.swn
+ *.swo
+
+ # pytorch models
+ *.pth.tar
+ result/
+
+ # setup.py
+ version.py
+
+ # jupyter dummy files
+ core
+
+ tests/outputs/*
+ TODO.txt
+ .vscode/*
+ data/*
+ notebooks/data/*
+ TTS/tts/layers/glow_tts/monotonic_align/core.c
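
When a pattern in this file doesn't behave as expected, `git check-ignore -v` reports which rule matched a given path (the path below is only an example):

```bash
# Print the .gitignore source file, line number, and pattern that
# causes a path to be ignored.
git check-ignore -v tests/outputs/example/checkpoint.pth.tar
```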
.pylintrc ADDED
@@ -0,0 +1,586 @@
+ [MASTER]
+
+ # A comma-separated list of package or module names from where C extensions may
+ # be loaded. Extensions are loading into the active Python interpreter and may
+ # run arbitrary code.
+ extension-pkg-whitelist=
+
+ # Add files or directories to the blacklist. They should be base names, not
+ # paths.
+ ignore=CVS
+
+ # Add files or directories matching the regex patterns to the blacklist. The
+ # regex matches against base names, not paths.
+ ignore-patterns=
+
+ # Python code to execute, usually for sys.path manipulation such as
+ # pygtk.require().
+ #init-hook=
+
+ # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
+ # number of processors available to use.
+ jobs=1
+
+ # Control the amount of potential inferred values when inferring a single
+ # object. This can help the performance when dealing with large functions or
+ # complex, nested conditions.
+ limit-inference-results=100
+
+ # List of plugins (as comma separated values of python modules names) to load,
+ # usually to register additional checkers.
+ load-plugins=
+
+ # Pickle collected data for later comparisons.
+ persistent=yes
+
+ # Specify a configuration file.
+ #rcfile=
+
+ # When enabled, pylint would attempt to guess common misconfiguration and emit
+ # user-friendly hints instead of false-positive error messages.
+ suggestion-mode=yes
+
+ # Allow loading of arbitrary C extensions. Extensions are imported into the
+ # active Python interpreter and may run arbitrary code.
+ unsafe-load-any-extension=no
+
+
+ [MESSAGES CONTROL]
+
+ # Only show warnings with the listed confidence levels. Leave empty to show
+ # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
+ confidence=
+
+ # Disable the message, report, category or checker with the given id(s). You
+ # can either give multiple identifiers separated by comma (,) or put this
+ # option multiple times (only on the command line, not in the configuration
+ # file where it should appear only once). You can also use "--disable=all" to
+ # disable everything first and then reenable specific checks. For example, if
+ # you want to run only the similarities checker, you can use "--disable=all
+ # --enable=similarities". If you want to run only the classes checker, but have
+ # no Warning level messages displayed, use "--disable=all --enable=classes
+ # --disable=W".
+ disable=missing-docstring,
+         line-too-long,
+         fixme,
+         wrong-import-order,
+         ungrouped-imports,
+         wrong-import-position,
+         import-error,
+         invalid-name,
+         too-many-instance-attributes,
+         arguments-differ,
+         no-name-in-module,
+         no-member,
+         unsubscriptable-object,
+         print-statement,
+         parameter-unpacking,
+         unpacking-in-except,
+         old-raise-syntax,
+         backtick,
+         long-suffix,
+         old-ne-operator,
+         old-octal-literal,
+         import-star-module-level,
+         non-ascii-bytes-literal,
+         raw-checker-failed,
+         bad-inline-option,
+         locally-disabled,
+         file-ignored,
+         suppressed-message,
+         useless-suppression,
+         deprecated-pragma,
+         use-symbolic-message-instead,
+         useless-object-inheritance,
+         too-few-public-methods,
+         too-many-branches,
+         too-many-arguments,
+         too-many-locals,
+         too-many-statements,
+         apply-builtin,
+         basestring-builtin,
+         buffer-builtin,
+         cmp-builtin,
+         coerce-builtin,
+         execfile-builtin,
+         file-builtin,
+         long-builtin,
+         raw_input-builtin,
+         reduce-builtin,
+         standarderror-builtin,
+         unicode-builtin,
+         xrange-builtin,
+         coerce-method,
+         delslice-method,
+         getslice-method,
+         setslice-method,
+         no-absolute-import,
+         old-division,
+         dict-iter-method,
+         dict-view-method,
+         next-method-called,
+         metaclass-assignment,
+         indexing-exception,
+         raising-string,
+         reload-builtin,
+         oct-method,
+         hex-method,
+         nonzero-method,
+         cmp-method,
+         input-builtin,
+         round-builtin,
+         intern-builtin,
+         unichr-builtin,
+         map-builtin-not-iterating,
+         zip-builtin-not-iterating,
+         range-builtin-not-iterating,
+         filter-builtin-not-iterating,
+         using-cmp-argument,
+         eq-without-hash,
+         div-method,
+         idiv-method,
+         rdiv-method,
+         exception-message-attribute,
+         invalid-str-codec,
+         sys-max-int,
+         bad-python3-import,
+         deprecated-string-function,
+         deprecated-str-translate-call,
+         deprecated-itertools-function,
+         deprecated-types-field,
+         next-method-defined,
+         dict-items-not-iterating,
+         dict-keys-not-iterating,
+         dict-values-not-iterating,
+         deprecated-operator-function,
+         deprecated-urllib-function,
+         xreadlines-attribute,
+         deprecated-sys-function,
+         exception-escape,
+         comprehension-escape,
+         duplicate-code
+
+ # Enable the message, report, category or checker with the given id(s). You can
+ # either give multiple identifier separated by comma (,) or put this option
+ # multiple time (only on the command line, not in the configuration file where
+ # it should appear only once). See also the "--disable" option for examples.
+ enable=c-extension-no-member
+
+
+ [REPORTS]
+
+ # Python expression which should return a note less than 10 (10 is the highest
+ # note). You have access to the variables errors warning, statement which
+ # respectively contain the number of errors / warnings messages and the total
+ # number of statements analyzed. This is used by the global evaluation report
+ # (RP0004).
+ evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
+
+ # Template used to display messages. This is a python new-style format string
+ # used to format the message information. See doc for all details.
+ #msg-template=
+
+ # Set the output format. Available formats are text, parseable, colorized, json
+ # and msvs (visual studio). You can also give a reporter class, e.g.
+ # mypackage.mymodule.MyReporterClass.
+ output-format=text
+
+ # Tells whether to display a full report or only the messages.
+ reports=no
+
+ # Activate the evaluation score.
+ score=yes
+
+
+ [REFACTORING]
+
+ # Maximum number of nested blocks for function / method body
+ max-nested-blocks=5
+
+ # Complete name of functions that never returns. When checking for
+ # inconsistent-return-statements if a never returning function is called then
+ # it will be considered as an explicit return statement and no message will be
+ # printed.
+ never-returning-functions=sys.exit
+
+
+ [LOGGING]
+
+ # Format style used to check logging format string. `old` means using %
+ # formatting, while `new` is for `{}` formatting.
+ logging-format-style=old
+
+ # Logging modules to check that the string format arguments are in logging
+ # function parameter format.
+ logging-modules=logging
+
+
+ [SPELLING]
+
+ # Limits count of emitted suggestions for spelling mistakes.
+ max-spelling-suggestions=4
+
+ # Spelling dictionary name. Available dictionaries: none. To make it working
+ # install python-enchant package..
+ spelling-dict=
+
+ # List of comma separated words that should not be checked.
+ spelling-ignore-words=
+
+ # A path to a file that contains private dictionary; one word per line.
+ spelling-private-dict-file=
+
+ # Tells whether to store unknown words to indicated private dictionary in
+ # --spelling-private-dict-file option instead of raising a message.
+ spelling-store-unknown-words=no
+
+
+ [MISCELLANEOUS]
+
+ # List of note tags to take in consideration, separated by a comma.
+ notes=FIXME,
+       XXX,
+       TODO
+
+
+ [TYPECHECK]
+
+ # List of decorators that produce context managers, such as
+ # contextlib.contextmanager. Add to this list to register other decorators that
+ # produce valid context managers.
+ contextmanager-decorators=contextlib.contextmanager
+
+ # List of members which are set dynamically and missed by pylint inference
+ # system, and so shouldn't trigger E1101 when accessed. Python regular
+ # expressions are accepted.
+ generated-members=
+
+ # Tells whether missing members accessed in mixin class should be ignored. A
+ # mixin class is detected if its name ends with "mixin" (case insensitive).
+ ignore-mixin-members=yes
+
+ # Tells whether to warn about missing members when the owner of the attribute
+ # is inferred to be None.
+ ignore-none=yes
+
+ # This flag controls whether pylint should warn about no-member and similar
+ # checks whenever an opaque object is returned when inferring. The inference
+ # can return multiple potential results while evaluating a Python object, but
+ # some branches might not be evaluated, which results in partial inference. In
+ # that case, it might be useful to still emit no-member and other checks for
+ # the rest of the inferred objects.
+ ignore-on-opaque-inference=yes
+
+ # List of class names for which member attributes should not be checked (useful
+ # for classes with dynamically set attributes). This supports the use of
+ # qualified names.
+ ignored-classes=optparse.Values,thread._local,_thread._local
+
+ # List of module names for which member attributes should not be checked
+ # (useful for modules/projects where namespaces are manipulated during runtime
+ # and thus existing member attributes cannot be deduced by static analysis. It
+ # supports qualified module names, as well as Unix pattern matching.
+ ignored-modules=
+
+ # Show a hint with possible names when a member name was not found. The aspect
+ # of finding the hint is based on edit distance.
+ missing-member-hint=yes
+
+ # The minimum edit distance a name should have in order to be considered a
+ # similar match for a missing member name.
+ missing-member-hint-distance=1
+
+ # The total number of similar names that should be taken in consideration when
+ # showing a hint for a missing member.
+ missing-member-max-choices=1
+
+
+ [VARIABLES]
+
+ # List of additional names supposed to be defined in builtins. Remember that
+ # you should avoid defining new builtins when possible.
+ additional-builtins=
+
+ # Tells whether unused global variables should be treated as a violation.
+ allow-global-unused-variables=yes
+
+ # List of strings which can identify a callback function by name. A callback
+ # name must start or end with one of those strings.
+ callbacks=cb_,
+           _cb
+
+ # A regular expression matching the name of dummy variables (i.e. expected to
+ # not be used).
+ dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+ # Argument names that match this expression will be ignored. Default to name
+ # with leading underscore.
+ ignored-argument-names=_.*|^ignored_|^unused_
+
+ # Tells whether we should check for unused import in __init__ files.
+ init-import=no
+
+ # List of qualified module names which can have objects that can redefine
+ # builtins.
+ redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
+
+
+ [FORMAT]
+
+ # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+ expected-line-ending-format=
+
+ # Regexp for a line that is allowed to be longer than the limit.
+ ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+ # Number of spaces of indent required inside a hanging or continued line.
+ indent-after-paren=4
+
+ # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+ # tab).
+ indent-string='    '
+
+ # Maximum number of characters on a single line.
+ max-line-length=100
+
+ # Maximum number of lines in a module.
+ max-module-lines=1000
+
+ # List of optional constructs for which whitespace checking is disabled. `dict-
+ # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
+ # `trailing-comma` allows a space between comma and closing bracket: (a, ).
+ # `empty-line` allows space-only lines.
+ no-space-check=trailing-comma,
+                dict-separator
+
+ # Allow the body of a class to be on the same line as the declaration if body
+ # contains single statement.
+ single-line-class-stmt=no
+
+ # Allow the body of an if to be on the same line as the test if there is no
+ # else.
+ single-line-if-stmt=no
+
+
+ [SIMILARITIES]
+
+ # Ignore comments when computing similarities.
+ ignore-comments=yes
+
+ # Ignore docstrings when computing similarities.
+ ignore-docstrings=yes
+
+ # Ignore imports when computing similarities.
+ ignore-imports=no
+
+ # Minimum lines number of a similarity.
+ min-similarity-lines=4
+
+
+ [BASIC]
+
+ # Naming style matching correct argument names.
+ argument-naming-style=snake_case
+
+ # Regular expression matching correct argument names. Overrides argument-
+ # naming-style.
+ argument-rgx=[a-z_][a-z0-9_]{0,30}$
+
+ # Naming style matching correct attribute names.
+ attr-naming-style=snake_case
+
+ # Regular expression matching correct attribute names. Overrides attr-naming-
+ # style.
+ #attr-rgx=
+
+ # Bad variable names which should always be refused, separated by a comma.
+ bad-names=
+
+ # Naming style matching correct class attribute names.
+ class-attribute-naming-style=any
+
+ # Regular expression matching correct class attribute names. Overrides class-
+ # attribute-naming-style.
+ #class-attribute-rgx=
+
+ # Naming style matching correct class names.
+ class-naming-style=PascalCase
+
+ # Regular expression matching correct class names. Overrides class-naming-
+ # style.
+ #class-rgx=
+
+ # Naming style matching correct constant names.
+ const-naming-style=UPPER_CASE
+
+ # Regular expression matching correct constant names. Overrides const-naming-
+ # style.
+ #const-rgx=
+
+ # Minimum line length for functions/classes that require docstrings, shorter
+ # ones are exempt.
+ docstring-min-length=-1
+
+ # Naming style matching correct function names.
+ function-naming-style=snake_case
+
+ # Regular expression matching correct function names. Overrides function-
+ # naming-style.
+ #function-rgx=
+
+ # Good variable names which should always be accepted, separated by a comma.
+ good-names=i,
+            j,
+            k,
+            x,
+            ex,
+            Run,
+            _
+
+ # Include a hint for the correct naming format with invalid-name.
+ include-naming-hint=no
+
+ # Naming style matching correct inline iteration names.
+ inlinevar-naming-style=any
+
+ # Regular expression matching correct inline iteration names. Overrides
+ # inlinevar-naming-style.
+ #inlinevar-rgx=
+
+ # Naming style matching correct method names.
+ method-naming-style=snake_case
+
+ # Regular expression matching correct method names. Overrides method-naming-
+ # style.
+ #method-rgx=
+
+ # Naming style matching correct module names.
+ module-naming-style=snake_case
+
+ # Regular expression matching correct module names. Overrides module-naming-
+ # style.
+ #module-rgx=
+
+ # Colon-delimited sets of names that determine each other's naming style when
+ # the name regexes allow several styles.
+ name-group=
+
+ # Regular expression which should only match function or class names that do
+ # not require a docstring.
+ no-docstring-rgx=^_
+
+ # List of decorators that produce properties, such as abc.abstractproperty. Add
+ # to this list to register other decorators that produce valid properties.
+ # These decorators are taken in consideration only for invalid-name.
+ property-classes=abc.abstractproperty
+
+ # Naming style matching correct variable names.
+ variable-naming-style=snake_case
+
+ # Regular expression matching correct variable names. Overrides variable-
+ # naming-style.
+ variable-rgx=[a-z_][a-z0-9_]{0,30}$
+
+
+ [STRING]
+
+ # This flag controls whether the implicit-str-concat-in-sequence should
+ # generate a warning on implicit string concatenation in sequences defined over
+ # several lines.
+ check-str-concat-over-line-jumps=no
+
+
+ [IMPORTS]
+
+ # Allow wildcard imports from modules that define __all__.
+ allow-wildcard-with-all=no
+
+ # Analyse import fallback blocks. This can be used to support both Python 2 and
+ # 3 compatible code, which means that the block might have code that exists
+ # only in one or another interpreter, leading to false positives when analysed.
+ analyse-fallback-blocks=no
+
+ # Deprecated modules which should not be used, separated by a comma.
+ deprecated-modules=optparse,tkinter.tix
+
+ # Create a graph of external dependencies in the given file (report RP0402 must
+ # not be disabled).
+ ext-import-graph=
+
+ # Create a graph of every (i.e. internal and external) dependencies in the
+ # given file (report RP0402 must not be disabled).
+ import-graph=
+
+ # Create a graph of internal dependencies in the given file (report RP0402 must
+ # not be disabled).
+ int-import-graph=
+
+ # Force import order to recognize a module as part of the standard
+ # compatibility libraries.
+ known-standard-library=
+
+ # Force import order to recognize a module as part of a third party library.
+ known-third-party=enchant
+
+
+ [CLASSES]
+
+ # List of method names used to declare (i.e. assign) instance attributes.
+ defining-attr-methods=__init__,
+                       __new__,
+                       setUp
+
+ # List of member names, which should be excluded from the protected access
+ # warning.
+ exclude-protected=_asdict,
+                   _fields,
+                   _replace,
+                   _source,
+                   _make
+
+ # List of valid names for the first argument in a class method.
+ valid-classmethod-first-arg=cls
+
+ # List of valid names for the first argument in a metaclass class method.
+ valid-metaclass-classmethod-first-arg=cls
+
+
+ [DESIGN]
+
+ # Maximum number of arguments for function / method.
+ max-args=5
+
+ # Maximum number of attributes for a class (see R0902).
+ max-attributes=7
+
+ # Maximum number of boolean expressions in an if statement.
+ max-bool-expr=5
+
+ # Maximum number of branch for function / method body.
+ max-branches=12
+
+ # Maximum number of locals for function / method body.
+ max-locals=15
+
+ # Maximum number of parents for a class (see R0901).
+ max-parents=7
+
+ # Maximum number of public methods for a class (see R0904).
+ max-public-methods=20
+
+ # Maximum number of return / yield for function / method body.
+ max-returns=6
+
+ # Maximum number of statements in function / method body.
+ max-statements=50
+
+ # Minimum number of public methods for a class (see R0903).
+ min-public-methods=2
+
+
+ [EXCEPTIONS]
+
+ # Exceptions that will emit a warning when being caught. Defaults to
+ # "BaseException, Exception".
+ overgeneral-exceptions=BaseException,
+                        Exception
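
CI runs pylint through cardboardlint, but the same configuration can be used standalone; a minimal sketch (pylint picks up `.pylintrc` from the current directory on its own, `--rcfile` just makes it explicit):

```bash
# Lint the whole package with the repository's pylint configuration.
pylint --rcfile=.pylintrc TTS/
```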
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,19 @@
+ # Ethical Notice
+
+ Please consider possible consequences and be mindful of any adversarial use cases of this project. In this regard, please contact us if you have any concerns.
+
+ # Community Participation Guidelines
+
+ This repository is governed by Mozilla's code of conduct and etiquette guidelines.
+ For more details, please read the
+ [Mozilla Community Participation Guidelines](https://www.mozilla.org/about/governance/policies/participation/).
+
+ ## How to Report
+ For more information on how to report violations of the Community Participation Guidelines, please read our '[How to Report](https://www.mozilla.org/about/governance/policies/participation/reporting/)' page.
+
+ <!--
+ ## Project Specific Etiquette
+
+ In some cases, there will be additional project etiquette i.e.: (https://bugzilla.mozilla.org/page.cgi?id=etiquette.html).
+ Please update for your project.
+ -->
CODE_OWNERS.rst ADDED
@@ -0,0 +1,75 @@
+ TTS code owners / governance system
+ ==========================================
+
+ TTS is run under a governance system inspired by (and partially copied from) the `Mozilla module ownership system <https://www.mozilla.org/about/governance/policies/module-ownership/>`_. The project is roughly divided into modules, and each module has its owners, who are responsible for reviewing pull requests and deciding on the technical direction of their modules. Module ownership authority is given to people who have worked extensively on areas of the project.
+
+ Module owners also have the authority to name other module owners or appoint module peers, who are people with authority to review pull requests in that module. They can also sub-divide their module into sub-modules with their own owners.
+
+ Module owners are not tyrants. They are chartered to make decisions with input from the community and in the best interest of the community. Module owners are not required to make code changes or additions solely because the community wants them to do so. (Like anyone else, the module owners may write code because they want to, because their employers want them to, because the community wants them to, or for some other reason.) Module owners do need to pay attention to patches submitted to that module. However, "pay attention" does not mean agreeing to every patch. Some patches may not make sense for the TTS project; some may be poorly implemented. Module owners have the authority to decline a patch; this is a necessary part of the role. We ask the module owners to describe in the relevant issue their reasons for wanting changes to a patch, for declining it altogether, or for postponing review for some period. We don't ask or expect them to rewrite patches to make them acceptable. Similarly, module owners may need to delay review of a promising patch due to an upcoming deadline. For example, a patch may be of interest, but not for the next milestone. In such a case it may make sense for the module owner to postpone review of a patch until after matters needed for a milestone have been finalized. Again, we expect this to be described in the relevant issue. And of course, it shouldn't go on very often or for very long, or escalation and review are likely.
+
+ The work of the various module owners and peers is overseen by the global owners, who are responsible for making final decisions in case of conflict between owners, as well as setting the direction for the project as a whole.
+
+ This file describes the module owners who are active on the project and which parts of the code they have expertise on (and interest in). If you're making changes to the code and are wondering who's an appropriate person to talk to, this list will tell you who to ping.
+
+ There's overlap in the areas of expertise of each owner, and in particular when looking at which files are covered by each area, there is a lot of overlap. Don't worry about getting it exactly right when requesting review; any code owner will be happy to redirect the request to a more appropriate person.
+
+ Global owners
+ ----------------
+
+ These are people who have worked on the project extensively and are familiar with all or most parts of it. Their expertise and review guidance is trusted by other code owners to cover their own areas of expertise. In case of conflicting opinions from other owners, global owners will make a final decision.
+
+ - Eren Gölge (@erogol)
+ - Reuben Morais (@reuben)
+
+ Training, feeding
+ -----------------
+
+ - Eren Gölge (@erogol)
+
+ Model exporting
+ ---------------
+
+ - Eren Gölge (@erogol)
+
+ Multi-Speaker TTS
+ -----------------
+
+ - Eren Gölge (@erogol)
+ - Edresson Casanova (@edresson)
+
+ TTS
+ ---
+
+ - Eren Gölge (@erogol)
+
+ Vocoders
+ --------
+
+ - Eren Gölge (@erogol)
+
+ Speaker Encoder
+ ---------------
+
+ - Eren Gölge (@erogol)
+
+ Testing & CI
+ ------------
+
+ - Eren Gölge (@erogol)
+ - Reuben Morais (@reuben)
+
+ Python bindings
+ ---------------
+
+ - Eren Gölge (@erogol)
+ - Reuben Morais (@reuben)
+
+ Documentation
+ -------------
+
+ - Eren Gölge (@erogol)
+
+ Third party bindings
+ --------------------
+
+ Owned by the author.
CONTRIBUTING.md ADDED
@@ -0,0 +1,51 @@
+ # Contribution guidelines
+
+ This repository is governed by Mozilla's code of conduct and etiquette guidelines. For more details, please read the [Mozilla Community Participation Guidelines](https://www.mozilla.org/about/governance/policies/participation/).
+
+ Before making a Pull Request, check your changes for basic mistakes and style problems by using a linter. We have cardboardlinter set up in this repository, so for example, if you've made some changes and would like to run the linter on just the differences between your work and master, you can use the following command:
+
+ ```bash
+ pip install pylint cardboardlint
+ cardboardlinter --refspec master
+ ```
+
+ This will compare the code against master and run the linter on all the changes. To run it automatically as a git pre-commit hook, you can do the following:
+
+ ```bash
+ cat <<\EOF > .git/hooks/pre-commit
+ #!/bin/bash
+ if [ ! -x "$(command -v cardboardlinter)" ]; then
+     exit 0
+ fi
+
+ # First, stash index and work dir, keeping only the
+ # to-be-committed changes in the working directory.
+ echo "Stashing working tree changes..." 1>&2
+ old_stash=$(git rev-parse -q --verify refs/stash)
+ git stash save -q --keep-index
+ new_stash=$(git rev-parse -q --verify refs/stash)
+
+ # If there were no changes (e.g., `--amend` or `--allow-empty`)
+ # then nothing was stashed, and we should skip everything,
+ # including the tests themselves. (Presumably the tests passed
+ # on the previous commit, so there is no need to re-run them.)
+ if [ "$old_stash" = "$new_stash" ]; then
+     echo "No changes, skipping lint." 1>&2
+     exit 0
+ fi
+
+ # Run tests
+ cardboardlinter --refspec HEAD -n auto
+ status=$?
+
+ # Restore changes
+ echo "Restoring working tree changes..." 1>&2
+ git reset --hard -q && git stash apply --index -q && git stash drop -q
+
+ # Exit with status from test-run: nonzero prevents commit
+ exit $status
+ EOF
+ chmod +x .git/hooks/pre-commit
+ ```
+
+ This will run the linters on just the changes made in your commit.
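
Once installed, the hook fires on every `git commit`; it can also be exercised by hand to check the stash-and-restore logic without actually committing:

```bash
# Run the hook directly; it lints staged changes and restores the
# working tree exactly as a real commit would.
.git/hooks/pre-commit
```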
LICENSE.txt ADDED
@@ -0,0 +1,373 @@
+ Mozilla Public License Version 2.0
+ ==================================
+
+ 1. Definitions
+ --------------
+
+ 1.1. "Contributor"
+     means each individual or legal entity that creates, contributes to
+     the creation of, or owns Covered Software.
+
+ 1.2. "Contributor Version"
+     means the combination of the Contributions of others (if any) used
+     by a Contributor and that particular Contributor's Contribution.
+
+ 1.3. "Contribution"
+     means Covered Software of a particular Contributor.
+
+ 1.4. "Covered Software"
+     means Source Code Form to which the initial Contributor has attached
+     the notice in Exhibit A, the Executable Form of such Source Code
+     Form, and Modifications of such Source Code Form, in each case
+     including portions thereof.
+
+ 1.5. "Incompatible With Secondary Licenses"
+     means
+
+     (a) that the initial Contributor has attached the notice described
+         in Exhibit B to the Covered Software; or
+
+     (b) that the Covered Software was made available under the terms of
+         version 1.1 or earlier of the License, but not also under the
+         terms of a Secondary License.
+
+ 1.6. "Executable Form"
+     means any form of the work other than Source Code Form.
+
+ 1.7. "Larger Work"
+     means a work that combines Covered Software with other material, in
+     a separate file or files, that is not Covered Software.
+
+ 1.8. "License"
+     means this document.
+
+ 1.9. "Licensable"
+     means having the right to grant, to the maximum extent possible,
+     whether at the time of the initial grant or subsequently, any and
+     all of the rights conveyed by this License.
+
+ 1.10. "Modifications"
+     means any of the following:
+
+     (a) any file in Source Code Form that results from an addition to,
+         deletion from, or modification of the contents of Covered
+         Software; or
+
+     (b) any new file in Source Code Form that contains any Covered
+         Software.
+
+ 1.11. "Patent Claims" of a Contributor
+     means any patent claim(s), including without limitation, method,
+     process, and apparatus claims, in any patent Licensable by such
+     Contributor that would be infringed, but for the grant of the
+     License, by the making, using, selling, offering for sale, having
+     made, import, or transfer of either its Contributions or its
+     Contributor Version.
+
+ 1.12. "Secondary License"
+     means either the GNU General Public License, Version 2.0, the GNU
+     Lesser General Public License, Version 2.1, the GNU Affero General
+     Public License, Version 3.0, or any later versions of those
+     licenses.
+
+ 1.13. "Source Code Form"
+     means the form of the work preferred for making modifications.
+
+ 1.14. "You" (or "Your")
+     means an individual or a legal entity exercising rights under this
+     License. For legal entities, "You" includes any entity that
+     controls, is controlled by, or is under common control with You. For
+     purposes of this definition, "control" means (a) the power, direct
+     or indirect, to cause the direction or management of such entity,
+     whether by contract or otherwise, or (b) ownership of more than
+     fifty percent (50%) of the outstanding shares or beneficial
+     ownership of such entity.
+
+ 2. License Grants and Conditions
+ --------------------------------
+
+ 2.1. Grants
+
+ Each Contributor hereby grants You a world-wide, royalty-free,
+ non-exclusive license:
+
+ (a) under intellectual property rights (other than patent or trademark)
+     Licensable by such Contributor to use, reproduce, make available,
+     modify, display, perform, distribute, and otherwise exploit its
+     Contributions, either on an unmodified basis, with Modifications, or
+     as part of a Larger Work; and
+
+ (b) under Patent Claims of such Contributor to make, use, sell, offer
+     for sale, have made, import, and otherwise transfer either its
+     Contributions or its Contributor Version.
+
+ 2.2. Effective Date
+
+ The licenses granted in Section 2.1 with respect to any Contribution
+ become effective for each Contribution on the date the Contributor first
+ distributes such Contribution.
+
+ 2.3. Limitations on Grant Scope
+
+ The licenses granted in this Section 2 are the only rights granted under
+ this License. No additional rights or licenses will be implied from the
+ distribution or licensing of Covered Software under this License.
+ Notwithstanding Section 2.1(b) above, no patent license is granted by a
+ Contributor:
+
+ (a) for any code that a Contributor has removed from Covered Software;
+     or
+
+ (b) for infringements caused by: (i) Your and any other third party's
+     modifications of Covered Software, or (ii) the combination of its
+     Contributions with other software (except as part of its Contributor
+     Version); or
+
+ (c) under Patent Claims infringed by Covered Software in the absence of
+     its Contributions.
+
+ This License does not grant any rights in the trademarks, service marks,
+ or logos of any Contributor (except as may be necessary to comply with
+ the notice requirements in Section 3.4).
+
+ 2.4. Subsequent Licenses
+
+ No Contributor makes additional grants as a result of Your choice to
+ distribute the Covered Software under a subsequent version of this
+ License (see Section 10.2) or under the terms of a Secondary License (if
+ permitted under the terms of Section 3.3).
+
+ 2.5. Representation
+
+ Each Contributor represents that the Contributor believes its
+ Contributions are its original creation(s) or it has sufficient rights
+ to grant the rights to its Contributions conveyed by this License.
+
+ 2.6. Fair Use
+
+ This License is not intended to limit any rights You have under
+ applicable copyright doctrines of fair use, fair dealing, or other
+ equivalents.
+
+ 2.7. Conditions
+
+ Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
+ in Section 2.1.
+
+ 3. Responsibilities
+ -------------------
+
+ 3.1. Distribution of Source Form
+
+ All distribution of Covered Software in Source Code Form, including any
+ Modifications that You create or to which You contribute, must be under
+ the terms of this License. You must inform recipients that the Source
+ Code Form of the Covered Software is governed by the terms of this
+ License, and how they can obtain a copy of this License. You may not
+ attempt to alter or restrict the recipients' rights in the Source Code
+ Form.
+
+ 3.2. Distribution of Executable Form
+
+ If You distribute Covered Software in Executable Form then:
+
+ (a) such Covered Software must also be made available in Source Code
+     Form, as described in Section 3.1, and You must inform recipients of
+     the Executable Form how they can obtain a copy of such Source Code
+     Form by reasonable means in a timely manner, at a charge no more
+     than the cost of distribution to the recipient; and
+
+ (b) You may distribute such Executable Form under the terms of this
+     License, or sublicense it under different terms, provided that the
+     license for the Executable Form does not attempt to limit or alter
+     the recipients' rights in the Source Code Form under this License.
+
+ 3.3. Distribution of a Larger Work
+
+ You may create and distribute a Larger Work under terms of Your choice,
+ provided that You also comply with the requirements of this License for
+ the Covered Software. If the Larger Work is a combination of Covered
+ Software with a work governed by one or more Secondary Licenses, and the
+ Covered Software is not Incompatible With Secondary Licenses, this
+ License permits You to additionally distribute such Covered Software
+ under the terms of such Secondary License(s), so that the recipient of
+ the Larger Work may, at their option, further distribute the Covered
+ Software under the terms of either this License or such Secondary
+ License(s).
+
+ 3.4. Notices
+
+ You may not remove or alter the substance of any license notices
+ (including copyright notices, patent notices, disclaimers of warranty,
+ or limitations of liability) contained within the Source Code Form of
+ the Covered Software, except that You may alter any license notices to
+ the extent required to remedy known factual inaccuracies.
+
+ 3.5. Application of Additional Terms
+
+ You may choose to offer, and to charge a fee for, warranty, support,
+ indemnity or liability obligations to one or more recipients of Covered
+ Software. However, You may do so only on Your own behalf, and not on
+ behalf of any Contributor. You must make it absolutely clear that any
+ such warranty, support, indemnity, or liability obligation is offered by
+ You alone, and You hereby agree to indemnify every Contributor for any
+ liability incurred by such Contributor as a result of warranty, support,
+ indemnity or liability terms You offer. You may include additional
+ disclaimers of warranty and limitations of liability specific to any
+ jurisdiction.
+
+ 4. Inability to Comply Due to Statute or Regulation
+ ---------------------------------------------------
+
+ If it is impossible for You to comply with any of the terms of this
+ License with respect to some or all of the Covered Software due to
+ statute, judicial order, or regulation then You must: (a) comply with
+ the terms of this License to the maximum extent possible; and (b)
+ describe the limitations and the code they affect. Such description must
+ be placed in a text file included with all distributions of the Covered
+ Software under this License. Except to the extent prohibited by statute
+ or regulation, such description must be sufficiently detailed for a
+ recipient of ordinary skill to be able to understand it.
+
+ 5. Termination
+ --------------
+
+ 5.1. The rights granted under this License will terminate automatically
+ if You fail to comply with any of its terms. However, if You become
+ compliant, then the rights granted under this License from a particular
+ Contributor are reinstated (a) provisionally, unless and until such
+ Contributor explicitly and finally terminates Your grants, and (b) on an
+ ongoing basis, if such Contributor fails to notify You of the
+ non-compliance by some reasonable means prior to 60 days after You have
+ come back into compliance. Moreover, Your grants from a particular
+ Contributor are reinstated on an ongoing basis if such Contributor
+ notifies You of the non-compliance by some reasonable means, this is the
+ first time You have received notice of non-compliance with this License
+ from such Contributor, and You become compliant prior to 30 days after
+ Your receipt of the notice.
+
+ 5.2. If You initiate litigation against any entity by asserting a patent
+ infringement claim (excluding declaratory judgment actions,
+ counter-claims, and cross-claims) alleging that a Contributor Version
+ directly or indirectly infringes any patent, then the rights granted to
+ You by any and all Contributors for the Covered Software under Section
+ 2.1 of this License shall terminate.
+
+ 5.3. In the event of termination under Sections 5.1 or 5.2 above, all
+ end user license agreements (excluding distributors and resellers) which
+ have been validly granted by You or Your distributors under this License
+ prior to termination shall survive termination.
+
+ ************************************************************************
+ *                                                                      *
+ *  6. Disclaimer of Warranty                                           *
+ *  -------------------------                                           *
+ *                                                                      *
+ *  Covered Software is provided under this License on an "as is"       *
+ *  basis, without warranty of any kind, either expressed, implied, or  *
+ *  statutory, including, without limitation, warranties that the       *
+ *  Covered Software is free of defects, merchantable, fit for a        *
+ *  particular purpose or non-infringing. The entire risk as to the     *
+ *  quality and performance of the Covered Software is with You.        *
+ *  Should any Covered Software prove defective in any respect, You     *
+ *  (not any Contributor) assume the cost of any necessary servicing,   *
+ *  repair, or correction. This disclaimer of warranty constitutes an   *
+ *  essential part of this License. No use of any Covered Software is   *
+ *  authorized under this License except under this disclaimer.         *
+ *                                                                      *
+ ************************************************************************
+
+ ************************************************************************
+ *                                                                      *
+ *  7. Limitation of Liability                                          *
+ *  --------------------------                                          *
+ *                                                                      *
+ *  Under no circumstances and under no legal theory, whether tort      *
+ *  (including negligence), contract, or otherwise, shall any           *
+ *  Contributor, or anyone who distributes Covered Software as          *
+ *  permitted above, be liable to You for any direct, indirect,         *
+ *  special, incidental, or consequential damages of any character      *
+ *  including, without limitation, damages for lost profits, loss of    *
+ *  goodwill, work stoppage, computer failure or malfunction, or any    *
+ *  and all other commercial damages or losses, even if such party      *
+ *  shall have been informed of the possibility of such damages. This   *
+ *  limitation of liability shall not apply to liability for death or   *
+ *  personal injury resulting from such party's negligence to the       *
+ *  extent applicable law prohibits such limitation. Some               *
+ *  jurisdictions do not allow the exclusion or limitation of           *
+ *  incidental or consequential damages, so this exclusion and          *
+ *  limitation may not apply to You.                                    *
+ *                                                                      *
+ ************************************************************************
+
+ 8. Litigation
+ -------------
+
+ Any litigation relating to this License may be brought only in the
+ courts of a jurisdiction where the defendant maintains its principal
+ place of business and such litigation shall be governed by laws of that
+ jurisdiction, without reference to its conflict-of-law provisions.
+ Nothing in this Section shall prevent a party's ability to bring
+ cross-claims or counter-claims.
+
+ 9. Miscellaneous
+ ----------------
+
+ This License represents the complete agreement concerning the subject
+ matter hereof. If any provision of this License is held to be
+ unenforceable, such provision shall be reformed only to the extent
+ necessary to make it enforceable. Any law or regulation which provides
+ that the language of a contract shall be construed against the drafter
+ shall not be used to construe this License against a Contributor.
+
+ 10. Versions of the License
+ ---------------------------
+
+ 10.1. New Versions
+
+ Mozilla Foundation is the license steward. Except as provided in Section
+ 10.3, no one other than the license steward has the right to modify or
+ publish new versions of this License. Each version will be given a
+ distinguishing version number.
+
+ 10.2. Effect of New Versions
+
+ You may distribute the Covered Software under the terms of the version
+ of the License under which You originally received the Covered Software,
+ or under the terms of any subsequent version published by the license
+ steward.
+
+ 10.3. Modified Versions
+
+ If you create software not governed by this License, and you want to
+ create a new license for such software, you may create and use a
+ modified version of this License if you rename the license and remove
+ any references to the name of the license steward (except to note that
+ such modified license differs from this License).
+
+ 10.4. Distributing Source Code Form that is Incompatible With Secondary
+ Licenses
+
+ If You choose to distribute Source Code Form that is Incompatible With
+ Secondary Licenses under the terms of this version of the License, the
+ notice described in Exhibit B of this License must be attached.
+
+ Exhibit A - Source Code Form License Notice
+ -------------------------------------------
+
+ This Source Code Form is subject to the terms of the Mozilla Public
+ License, v. 2.0. If a copy of the MPL was not distributed with this
+ file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+ If it is not possible or desirable to put the notice in a particular
+ file, then You may include the notice in a location (such as a LICENSE
+ file in a relevant directory) where a recipient would be likely to look
+ for such a notice.
+
+ You may add additional accurate notices of copyright ownership.
+
+ Exhibit B - "Incompatible With Secondary Licenses" Notice
+ ---------------------------------------------------------
+
+ This Source Code Form is "Incompatible With Secondary Licenses", as
+ defined by the Mozilla Public License, v. 2.0.
MANIFEST.in ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ include README.md
2
+ include LICENSE.txt
3
+ include requirements.txt
4
+ recursive-include TTS *.json
5
+ recursive-include TTS *.html
6
+ recursive-include TTS *.png
7
+ recursive-include TTS *.md
8
+ recursive-include TTS *.py
9
+ recursive-include TTS *.pyx
10
+ recursive-include images *.png
11
+
README.md CHANGED
@@ -1,3 +1,281 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <img src="https://user-images.githubusercontent.com/1402048/104139991-3fd15e00-53af-11eb-8640-3a78a64641dd.png" data-canonical-src="![TTS banner](https://user-images.githubusercontent.com/1402048/104139991-3fd15e00-53af-11eb-8640-3a78a64641dd.png =250x250)
2
+ " width="256" height="256" align="right" />
3
+
4
+ # TTS: Text-to-Speech for all.
5
+
6
+ TTS is a library for advanced Text-to-Speech generation. It's built on the latest research and designed to achieve the best trade-off among ease of training, speed and quality.
7
+ TTS comes with [pretrained models](https://github.com/mozilla/TTS/wiki/Released-Models) and tools for measuring dataset quality, and is already used in **20+ languages** for products and research projects.
8
+
9
+ [![CircleCI](<https://circleci.com/gh/mozilla/TTS/tree/dev.svg?style=svg>)](https://circleci.com/gh/mozilla/TTS/tree/dev)
10
+ [![License](<https://img.shields.io/badge/License-MPL%202.0-brightgreen.svg>)](https://opensource.org/licenses/MPL-2.0)
11
+ [![PyPI version](https://badge.fury.io/py/TTS.svg)](https://badge.fury.io/py/TTS)
12
+
13
+ :loudspeaker: [English Voice Samples](https://erogol.github.io/ddc-samples/) and [SoundCloud playlist](https://soundcloud.com/user-565970875/pocket-article-wavernn-and-tacotron2)
14
+
15
+ :man_cook: [TTS training recipes](https://github.com/erogol/TTS_recipes)
16
+
17
+ :page_facing_up: [Text-to-Speech paper collection](https://github.com/erogol/TTS-papers)
18
+
19
+ ## 💬 Where to ask questions
20
+ Please use our dedicated channels for questions and discussion. Help is much more valuable if it's shared publicly, so that more people can benefit from it.
21
+
22
+ | Type | Platforms |
23
+ | ------------------------------- | --------------------------------------- |
24
+ | 🚨 **Bug Reports** | [GitHub Issue Tracker] |
25
+ | ❔ **FAQ** | [TTS/Wiki](https://github.com/mozilla/TTS/wiki/FAQ) |
26
+ | 🎁 **Feature Requests & Ideas** | [GitHub Issue Tracker] |
27
+ | 👩‍💻 **Usage Questions** | [Discourse Forum] |
28
+ | 🗯 **General Discussion** | [Discourse Forum] and [Matrix Channel] |
29
+
30
+ [github issue tracker]: https://github.com/mozilla/tts/issues
31
+ [discourse forum]: https://discourse.mozilla.org/c/tts/
32
+ [matrix channel]: https://matrix.to/#/!KTePhNahjgiVumkqca:matrix.org?via=matrix.org
33
+ [Tutorials and Examples]: https://github.com/mozilla/TTS/wiki/TTS-Notebooks-and-Tutorials
34
+
35
+
36
+ ## 🔗 Links and Resources
37
+ | Type | Links |
38
+ | ------------------------------- | --------------------------------------- |
39
+ | 💾 **Installation** | [TTS/README.md](https://github.com/mozilla/TTS/tree/dev#install-tts)|
40
+ | 👩🏾‍🏫 **Tutorials and Examples** | [TTS/Wiki](https://github.com/mozilla/TTS/wiki/TTS-Notebooks-and-Tutorials) |
41
+ | 🚀 **Released Models** | [TTS/Wiki](https://github.com/mozilla/TTS/wiki/Released-Models)|
42
+ | 💻 **Docker Image** | [Repository by @synesthesiam](https://github.com/synesthesiam/docker-mozillatts)|
43
+ | 🖥️ **Demo Server** | [TTS/server](https://github.com/mozilla/TTS/tree/master/TTS/server)|
44
+ | 🤖 **Running TTS on Terminal** | [TTS/README.md](https://github.com/mozilla/TTS#example-synthesizing-speech-on-terminal-using-the-released-models)|
45
+ | ✨ **How to contribute** |[TTS/README.md](#contribution-guidelines)|
46
+
47
+ ## 🥇 TTS Performance
48
+ <p align="center"><img src="https://discourse-prod-uploads-81679984178418.s3.dualstack.us-west-2.amazonaws.com/optimized/3X/6/4/6428f980e9ec751c248e591460895f7881aec0c6_2_1035x591.png" width="800" /></p>
49
+
50
+ "Mozilla*" and "Judy*" are our models.
51
+ [Details...](https://github.com/mozilla/TTS/wiki/Mean-Opinion-Score-Results)
52
+
53
+ ## Features
54
+ - High performance Deep Learning models for Text2Speech tasks.
55
+ - Text2Spec models (Tacotron, Tacotron2, Glow-TTS, SpeedySpeech).
56
+ - Speaker Encoder to compute speaker embeddings efficiently.
57
+ - Vocoder models (MelGAN, Multiband-MelGAN, GAN-TTS, ParallelWaveGAN, WaveGrad, WaveRNN)
58
+ - Fast and efficient model training.
59
+ - Detailed training logs on console and Tensorboard.
60
+ - Support for multi-speaker TTS.
61
+ - Efficient Multi-GPUs training.
62
+ - Ability to convert PyTorch models to Tensorflow 2.0 and TFLite for inference.
63
+ - Released models in PyTorch, Tensorflow and TFLite.
64
+ - Tools to curate Text2Speech datasets under ```dataset_analysis```.
65
+ - Demo server for model testing.
66
+ - Notebooks for extensive model benchmarking.
67
+ - Modular (but not too much) code base enabling easy testing for new ideas.
68
+
69
+ ## Implemented Models
70
+ ### Text-to-Spectrogram
71
+ - Tacotron: [paper](https://arxiv.org/abs/1703.10135)
72
+ - Tacotron2: [paper](https://arxiv.org/abs/1712.05884)
73
+ - Glow-TTS: [paper](https://arxiv.org/abs/2005.11129)
74
+ - Speedy-Speech: [paper](https://arxiv.org/abs/2008.03802)
75
+
76
+ ### Attention Methods
77
+ - Guided Attention: [paper](https://arxiv.org/abs/1710.08969)
78
+ - Forward Backward Decoding: [paper](https://arxiv.org/abs/1907.09006)
79
+ - Graves Attention: [paper](https://arxiv.org/abs/1907.09006)
80
+ - Double Decoder Consistency: [blog](https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency/)
81
+
82
+ ### Speaker Encoder
83
+ - GE2E: [paper](https://arxiv.org/abs/1710.10467)
84
+ - Angular Loss: [paper](https://arxiv.org/pdf/2003.11982.pdf)
85
+
86
+ ### Vocoders
87
+ - MelGAN: [paper](https://arxiv.org/abs/1910.06711)
88
+ - MultiBandMelGAN: [paper](https://arxiv.org/abs/2005.05106)
89
+ - ParallelWaveGAN: [paper](https://arxiv.org/abs/1910.11480)
90
+ - GAN-TTS discriminators: [paper](https://arxiv.org/abs/1909.11646)
91
+ - WaveRNN: [origin](https://github.com/fatchord/WaveRNN/)
92
+ - WaveGrad: [paper](https://arxiv.org/abs/2009.00713)
93
+
94
+ You can also help us implement more models. Some TTS related work can be found [here](https://github.com/erogol/TTS-papers).
95
+
96
+ ## Install TTS
97
+ TTS supports **python >= 3.6, <3.9**.
98
+
99
+ If you are only interested in [synthesizing speech](https://github.com/mozilla/TTS/tree/dev#example-synthesizing-speech-on-terminal-using-the-released-models) with the released TTS models, installing from PyPI is the easiest option.
100
+
101
+ ```bash
102
+ pip install TTS
103
+ ```
104
+
105
+ If you plan to code or train models, clone TTS and install it locally.
106
+
107
+ ```bash
108
+ git clone https://github.com/mozilla/TTS
109
+ pip install -e .
110
+ ```
111
+
112
+ ## Directory Structure
113
+ ```
114
+ |- notebooks/ (Jupyter Notebooks for model evaluation, parameter selection and data analysis.)
115
+ |- utils/ (common utilities.)
116
+ |- TTS
117
+ |- bin/ (folder for all the executables.)
118
+ |- train*.py (train your target model.)
119
+ |- distribute.py (train your TTS model using Multiple GPUs.)
120
+ |- compute_statistics.py (compute dataset statistics for normalization.)
121
+ |- convert*.py (convert target torch model to TF.)
122
+ |- tts/ (text to speech models)
123
+ |- layers/ (model layer definitions)
124
+ |- models/ (model definitions)
125
+ |- tf/ (Tensorflow 2 utilities and model implementations)
126
+ |- utils/ (model specific utilities.)
127
+ |- speaker_encoder/ (Speaker Encoder models.)
128
+ |- (same)
129
+ |- vocoder/ (Vocoder models.)
130
+ |- (same)
131
+ ```
132
+
133
+ ## Sample Model Output
134
+ Below you can see the Tacotron model state after 16K iterations with batch size 32 on the LJSpeech dataset.
135
+
136
+ > "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning."
137
+
138
+ Audio examples: [soundcloud](https://soundcloud.com/user-565970875/pocket-article-wavernn-and-tacotron2)
139
+
140
+ <img src="images/example_model_output.png?raw=true" alt="example_output" width="400"/>
141
+
142
+ ## Datasets and Data-Loading
143
+ TTS provides a generic data loader that is easy to use with your custom dataset.
144
+ You just need to write a simple function to format the dataset. Check ```datasets/preprocess.py``` to see some examples.
145
+ After that, you need to set ```dataset``` fields in ```config.json```.
146
+
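+ As an illustration, a formatter for a pipe-separated ```metadata.csv``` with ```file_id|transcript``` rows could look like the minimal sketch below. The function name is hypothetical, not the library API; the ```[text, wav_path, speaker_name]``` item layout follows the formatters in ```datasets/preprocess.py```.
+
+ ```python
+ import os
+
+ def my_dataset(root_path, meta_file):
+     """Return a list of [text, wav_path, speaker_name] items for the data loader."""
+     items = []
+     with open(os.path.join(root_path, meta_file), encoding='utf-8') as f:
+         for line in f:
+             file_id, text = line.strip().split('|')
+             wav_path = os.path.join(root_path, 'wavs', file_id + '.wav')
+             items.append([text, wav_path, 'my_speaker'])  # single-speaker placeholder name
+     return items
+ ```
+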
147
+ Some of the public datasets to which we have successfully applied TTS:
148
+
149
+ - [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)
150
+ - [Nancy](http://www.cstr.ed.ac.uk/projects/blizzard/2011/lessac_blizzard2011/)
151
+ - [TWEB](https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset)
152
+ - [M-AI-Labs](http://www.caito.de/2019/01/the-m-ailabs-speech-dataset/)
153
+ - [LibriTTS](https://openslr.org/60/)
154
+ - [Spanish](https://drive.google.com/file/d/1Sm_zyBo67XHkiFhcRSQ4YaHPYM0slO_e/view?usp=sharing) - thx! @carlfm01
155
+
156
+ ## Example: Synthesizing Speech on Terminal Using the Released Models.
157
+
158
+ After the installation, TTS provides a CLI interface for synthesizing speech using pre-trained models. You can either use your own model or the released models under the TTS project.
159
+
160
+ Listing released TTS models.
161
+ ```bash
162
+ tts --list_models
163
+ ```
164
+
165
+ Run a TTS and a vocoder model from the released model list. (Simply copy and paste the full model names from the list as arguments to the command below.)
166
+ ```bash
167
+ tts --text "Text for TTS" \
168
+ --model_name "<type>/<language>/<dataset>/<model_name>" \
169
+ --vocoder_name "<type>/<language>/<dataset>/<model_name>" \
170
+ --out_path folder/to/save/output/
171
+ ```
172
+
173
+ Run your own TTS model (Using Griffin-Lim Vocoder)
174
+ ```bash
175
+ tts --text "Text for TTS" \
176
+ --model_path path/to/model.pth.tar \
177
+ --config_path path/to/config.json \
178
+ --out_path output/path/speech.wav
179
+ ```
180
+
181
+ Run your own TTS and Vocoder models
182
+ ```bash
183
+ tts --text "Text for TTS" \
184
+ --model_path path/to/model.pth.tar \
185
+ --config_path path/to/config.json \
186
+ --out_path output/path/speech.wav \
187
+ --vocoder_path path/to/vocoder.pth.tar \
188
+ --vocoder_config_path path/to/vocoder_config.json
189
+ ```
190
+
191
+ **Note:** You can use ```./TTS/bin/synthesize.py``` if you prefer running ```tts``` from the TTS project folder.
192
+
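+ The same pipeline is also callable from Python through the ```Synthesizer``` class used by ```TTS/bin/synthesize.py```. A minimal sketch (paths are placeholders; pass ```None``` for the vocoder arguments to fall back to Griffin-Lim):
+
+ ```python
+ from TTS.utils.synthesizer import Synthesizer
+
+ # load a TTS model and, optionally, a vocoder model
+ synthesizer = Synthesizer("path/to/model.pth.tar", "path/to/config.json",
+                           "path/to/vocoder.pth.tar", "path/to/vocoder_config.json",
+                           use_cuda=False)
+ wav = synthesizer.tts("Text for TTS")
+ synthesizer.save_wav(wav, "output/path/speech.wav")
+ ```
+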
193
+ ## Example: Training and Fine-tuning LJ-Speech Dataset
194
+ Here you can find a [Colab](https://gist.github.com/erogol/97516ad65b44dbddb8cd694953187c5b) notebook for a hands-on example, training LJSpeech. Or you can manually follow the guideline below.
195
+
196
+ To start with, split ```metadata.csv``` into train and validation subsets, respectively ```metadata_train.csv``` and ```metadata_val.csv```. Note that for text-to-speech, validation performance can be misleading, since the loss value does not directly measure voice quality to the human ear, nor does it measure the attention module's performance. Therefore, running the model on new sentences and listening to the results is the best way to go.
197
+
198
+ ```
199
+ shuf metadata.csv > metadata_shuf.csv
200
+ head -n 12000 metadata_shuf.csv > metadata_train.csv
201
+ tail -n 1100 metadata_shuf.csv > metadata_val.csv
202
+ ```
203
+
204
+ To train a new model, you need to write your own ```config.json``` defining the model details, training configuration and more (check the examples). Then call the corresponding train script.
205
+
206
+ For instance, in order to train a Tacotron or Tacotron2 model on the LJSpeech dataset, follow these steps.
207
+
208
+ ```bash
209
+ python TTS/bin/train_tacotron.py --config_path TTS/tts/configs/config.json
210
+ ```
211
+
212
+ To fine-tune a model, use ```--restore_path```.
213
+
214
+ ```bash
215
+ python TTS/bin/train_tacotron.py --config_path TTS/tts/configs/config.json --restore_path /path/to/your/model.pth.tar
216
+ ```
217
+
218
+ To continue an old training run, use ```--continue_path```.
219
+
220
+ ```bash
221
+ python TTS/bin/train_tacotron.py --continue_path /path/to/your/run_folder/
222
+ ```
223
+
224
+ For multi-GPU training, call ```distribute.py```. It runs any provided train script in a multi-GPU setting.
225
+
226
+ ```bash
227
+ CUDA_VISIBLE_DEVICES="0,1,4" python TTS/bin/distribute.py --script train_tacotron.py --config_path TTS/tts/configs/config.json
228
+ ```
229
+
230
+ Each run creates a new output folder containing the used ```config.json```, model checkpoints and Tensorboard logs.
231
+
232
+ In case of an error or interrupted execution, if there is no checkpoint yet under the output folder, the whole folder is going to be removed.
233
+
234
+ You can also use Tensorboard by pointing its ```--logdir``` argument to the experiment folder.
235
+
236
+ ## Contribution Guidelines
237
+ This repository is governed by Mozilla's code of conduct and etiquette guidelines. For more details, please read the [Mozilla Community Participation Guidelines.](https://www.mozilla.org/about/governance/policies/participation/)
238
+
239
+ 1. Create a new branch.
240
+ 2. Implement your changes.
241
+ 3. (if applicable) Add [Google Style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings.
242
+ 4. (if applicable) Implement a test case under ```tests``` folder.
243
+ 5. (Optional but preferred) Run tests.
244
+ ```bash
245
+ ./run_tests.sh
246
+ ```
247
+ 6. Run the linter.
248
+ ```bash
249
+ pip install pylint cardboardlint
250
+ cardboardlinter --refspec master
251
+ ```
252
+ 7. Send a PR to the ```dev``` branch and explain what the change is about.
253
+ 8. Let us discuss until we make it perfect :).
254
+ 9. We merge it to the ```dev``` branch once things look good.
255
+
256
+ Feel free to ping us at any step you need help using our communication channels.
257
+
258
+ ## Collaborative Experimentation Guide
259
+ If you would like to use TTS to try a new idea and share your experiments with the community, we urge you to follow the guidelines below for better collaboration.
260
+ (If you have an idea for better collaboration, let us know)
261
+ - Create a new branch.
262
+ - Open an issue pointing your branch.
263
+ - Explain your idea and experiment.
264
+ - Share your results regularly. (Tensorboard log files, audio results, visuals etc.)
265
+
266
+ ## Major TODOs
267
+ - [x] Implement the model.
268
+ - [x] Generate human-like speech on LJSpeech dataset.
269
+ - [x] Generate human-like speech on a different dataset (Nancy) (TWEB).
270
+ - [x] Train TTS with r=1 successfully.
271
+ - [x] Enable process-based distributed training. Similar to (https://github.com/fastai/imagenet-fast/).
272
+ - [x] Adapting Neural Vocoder. TTS works with WaveRNN and ParallelWaveGAN (https://github.com/erogol/WaveRNN and https://github.com/erogol/ParallelWaveGAN)
273
+ - [x] Multi-speaker embedding.
274
+ - [x] Model optimization (model export, model pruning etc.)
275
+
276
+ ### Acknowledgement
277
+ - https://github.com/keithito/tacotron (Dataset pre-processing)
278
+ - https://github.com/r9y9/tacotron_pytorch (Initial Tacotron architecture)
279
+ - https://github.com/kan-bayashi/ParallelWaveGAN (vocoder library)
280
+ - https://github.com/jaywalnut310/glow-tts (Original Glow-TTS implementation)
281
+ - https://github.com/fatchord/WaveRNN/ (Original WaveRNN implementation)
TTS/.models.json ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "tts_models":{
3
+ "en":{
4
+ "ljspeech":{
5
+ "glow-tts":{
6
+ "description": "",
7
+ "model_file": "1NFsfhH8W8AgcfJ-BsL8CYAwQfZ5k4T-n",
8
+ "config_file": "1IAROF3yy9qTK43vG_-R67y3Py9yYbD6t",
9
+ "stats_file": null,
10
+ "commit": ""
11
+ },
12
+ "tacotron2-DCA": {
13
+ "description": "",
14
+ "model_file": "1CFoPDQBnhfBFu2Gc0TBSJn8o-TuNKQn7",
15
+ "config_file": "1lWSscNfKet1zZSJCNirOn7v9bigUZ8C1",
16
+ "stats_file": "1qevpGRVHPmzfiRBNuugLMX62x1k7B5vK",
17
+ "commit": ""
18
+ },
19
+ "speedy-speech-wn":{
20
+ "description": "Speedy Speech model with wavenet decoder.",
21
+ "model_file": "1VXAwiq6N-Viq3rsSXlf43bdoi0jSvMAJ",
22
+ "config_file": "1KvZilhsNP3EumVggDcD46yd834eO5hR3",
23
+ "stats_file": "1Ju7apZ5JlgsVECcETL-GEx3DRoNzWfkR",
24
+ "commit": "77b6145"
25
+ }
26
+ }
27
+ },
28
+ "es":{
29
+ "mai":{
30
+ "tacotron2-DDC":{
31
+ "model_file": "1jZ4HvYcAXI5ZClke2iGA7qFQQJBXIovw",
32
+ "config_file": "1s7g4n-B73ChCB48AQ88_DV_8oyLth8r0",
33
+ "stats_file": "13st0CZ743v6Br5R5Qw_lH1OPQOr3M-Jv",
34
+ "commit": ""
35
+ }
36
+ }
37
+ },
38
+ "fr":{
39
+ "mai":{
40
+ "tacotron2-DDC":{
41
+ "model_file": "1qyxrrCyoXUvBG2lqVd0KqAlHj-2nZCgS",
42
+ "config_file": "1yECKeP2LI7tNv4E8yVNx1yLmCfTCpkqG",
43
+ "stats_file": "13st0CZ743v6Br5R5Qw_lH1OPQOr3M-Jv",
44
+ "commit": ""
45
+ }
46
+ }
47
+ }
48
+ },
49
+ "vocoder_models":{
50
+ "universal":{
51
+ "libri-tts":{
52
+ "wavegrad":{
53
+ "model_file": "1r2g90JaZsfCj9dJkI9ioIU6JCFMPRqi6",
54
+ "config_file": "1POrrLf5YEpZyjvWyMccj1nGCVc94mR6s",
55
+ "stats_file": "1Vwbv4t-N1i3jXqI0bgKAhShAEO097sK0",
56
+ "commit": "ea976b0"
57
+ },
58
+ "fullband-melgan":{
59
+ "model_file": "1Ty5DZdOc0F7OTGj9oJThYbL5iVu_2G0K",
60
+ "config_file": "1Rd0R_nRCrbjEdpOwq6XwZAktvugiBvmu",
61
+ "stats_file": "11oY3Tv0kQtxK_JPgxrfesa99maVXHNxU",
62
+ "commit": "4132240"
63
+ }
64
+ }
65
+ },
66
+ "en": {
67
+ "ljspeech":{
68
+ "mulitband-melgan":{
69
+ "model_file": "1Ty5DZdOc0F7OTGj9oJThYbL5iVu_2G0K",
70
+ "config_file": "1Rd0R_nRCrbjEdpOwq6XwZAktvugiBvmu",
71
+ "stats_file": "11oY3Tv0kQtxK_JPgxrfesa99maVXHNxU",
72
+ "commit": "ea976b0"
73
+ }
74
+ }
75
+ }
76
+ }
77
+ }
TTS/__init__.py ADDED
File without changes
TTS/bin/__init__.py ADDED
File without changes
TTS/bin/compute_attention_masks.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import importlib
3
+ import os
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+ from tqdm import tqdm
9
+ from argparse import RawTextHelpFormatter
10
+ from TTS.tts.datasets.TTSDataset import MyDataset
11
+ from TTS.tts.utils.generic_utils import setup_model
12
+ from TTS.tts.utils.io import load_checkpoint
13
+ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
14
+ from TTS.utils.audio import AudioProcessor
15
+ from TTS.utils.io import load_config
16
+
17
+
18
+ if __name__ == '__main__':
19
+ parser = argparse.ArgumentParser(
20
+ description='''Extract attention masks from trained Tacotron/Tacotron2 models.
21
+ These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n'''
22
+
23
+ '''Each attention mask is written to the same path as the input wav file with ".npy" file extension.
24
+ (e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n'''
25
+
26
+ '''
27
+ Example run:
28
+ CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
29
+ --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
30
+ --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
31
+ --dataset_metafile /root/LJSpeech-1.1/metadata.csv
32
+ --data_path /root/LJSpeech-1.1/
33
+ --batch_size 32
34
+ --dataset ljspeech
35
+ --use_cuda True
36
+ ''',
37
+ formatter_class=RawTextHelpFormatter
38
+ )
39
+ parser.add_argument('--model_path',
40
+ type=str,
41
+ required=True,
42
+ help='Path to Tacotron/Tacotron2 model file ')
43
+ parser.add_argument(
44
+ '--config_path',
45
+ type=str,
46
+ required=True,
47
+ help='Path to Tacotron/Tacotron2 config file.',
48
+ )
49
+ parser.add_argument('--dataset',
50
+ type=str,
51
+ default='',
52
+ required=True,
53
+ help='Target dataset processor name from TTS.tts.dataset.preprocess.')
54
+
55
+ parser.add_argument(
56
+ '--dataset_metafile',
57
+ type=str,
58
+ default='',
59
+ required=True,
60
+ help='Dataset metafile including file paths with transcripts.')
61
+ parser.add_argument(
62
+ '--data_path',
63
+ type=str,
64
+ default='',
65
+ help='Defines the data path. It overwrites config.json.')
66
+ parser.add_argument('--use_cuda',
67
+ type=bool,
68
+ default=False,
69
+ help="enable/disable cuda.")
70
+
71
+ parser.add_argument(
72
+ '--batch_size',
73
+ default=16,
74
+ type=int,
75
+ help='Batch size for the model. Use batch_size=1 if you have no CUDA.')
76
+ args = parser.parse_args()
77
+
78
+ C = load_config(args.config_path)
79
+ ap = AudioProcessor(**C.audio)
80
+
81
+ # if the vocabulary was passed, replace the default
82
+ if 'characters' in C.keys():
83
+ symbols, phonemes = make_symbols(**C.characters)
84
+
85
+ # load the model
86
+ num_chars = len(phonemes) if C.use_phonemes else len(symbols)
87
+ # TODO: handle multi-speaker
88
+ model = setup_model(num_chars, num_speakers=0, c=C)
89
+ model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
90
+ model.eval()
91
+
92
+ # data loader
93
+ preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')
94
+ preprocessor = getattr(preprocessor, args.dataset)
95
+ meta_data = preprocessor(args.data_path, args.dataset_metafile)
96
+ dataset = MyDataset(model.decoder.r,
97
+ C.text_cleaner,
98
+ compute_linear_spec=False,
99
+ ap=ap,
100
+ meta_data=meta_data,
101
+ tp=C.characters if 'characters' in C.keys() else None,
102
+ add_blank=C['add_blank'] if 'add_blank' in C.keys() else False,
103
+ use_phonemes=C.use_phonemes,
104
+ phoneme_cache_path=C.phoneme_cache_path,
105
+ phoneme_language=C.phoneme_language,
106
+ enable_eos_bos=C.enable_eos_bos_chars)
107
+
108
+ dataset.sort_items()
109
+ loader = DataLoader(dataset,
110
+ batch_size=args.batch_size,
111
+ num_workers=4,
112
+ collate_fn=dataset.collate_fn,
113
+ shuffle=False,
114
+ drop_last=False)
115
+
116
+ # compute attentions
117
+ file_paths = []
118
+ with torch.no_grad():
119
+ for data in tqdm(loader):
120
+ # setup input data
121
+ text_input = data[0]
122
+ text_lengths = data[1]
123
+ linear_input = data[3]
124
+ mel_input = data[4]
125
+ mel_lengths = data[5]
126
+ stop_targets = data[6]
127
+ item_idxs = data[7]
128
+
129
+ # dispatch data to GPU
130
+ if args.use_cuda:
131
+ text_input = text_input.cuda()
132
+ text_lengths = text_lengths.cuda()
133
+ mel_input = mel_input.cuda()
134
+ mel_lengths = mel_lengths.cuda()
135
+
136
+ mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(
137
+ text_input, text_lengths, mel_input)
138
+
139
+ alignments = alignments.detach()
140
+ for idx, alignment in enumerate(alignments):
141
+ item_idx = item_idxs[idx]
142
+ # interpolate if r > 1
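+ # (the decoder emits r frames per step, so the alignment is upsampled along time by r to match the mel length)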
143
+ alignment = torch.nn.functional.interpolate(
144
+ alignment.transpose(0, 1).unsqueeze(0),
145
+ size=None,
146
+ scale_factor=model.decoder.r,
147
+ mode='nearest',
148
+ align_corners=None,
149
+ recompute_scale_factor=None).squeeze(0).transpose(0, 1)
150
+ # remove paddings
151
+ alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy()
152
+ # set file paths
153
+ wav_file_name = os.path.basename(item_idx)
154
+ align_file_name = os.path.splitext(wav_file_name)[0] + '.npy'
155
+ file_path = item_idx.replace(wav_file_name, align_file_name)
156
+ # save output
157
+ file_paths.append([item_idx, file_path])
158
+ np.save(file_path, alignment)
159
+
160
+ # output metafile
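+ # each line maps a wav file to its attention mask: <wav_path>|<alignment_npy_path>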
161
+ metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
162
+
163
+ with open(metafile, "w") as f:
164
+ for p in file_paths:
165
+ f.write(f"{p[0]}|{p[1]}\n")
166
+ print(f" >> Metafile created: {metafile}")
TTS/bin/compute_embeddings.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import os
4
+
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
+ import torch
9
+ from TTS.speaker_encoder.model import SpeakerEncoder
10
+ from TTS.utils.audio import AudioProcessor
11
+ from TTS.utils.io import load_config
12
+ from TTS.tts.utils.speakers import save_speaker_mapping
13
+ from TTS.tts.datasets.preprocess import load_meta_data
14
+
15
+ parser = argparse.ArgumentParser(
16
+ description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.')
17
+ parser.add_argument(
18
+ 'model_path',
19
+ type=str,
20
+ help='Path to model outputs (checkpoint, tensorboard etc.).')
21
+ parser.add_argument(
22
+ 'config_path',
23
+ type=str,
24
+ help='Path to config file for training.',
25
+ )
26
+ parser.add_argument(
27
+ 'data_path',
28
+ type=str,
29
+ help='Data path for wav files - directory or CSV file')
30
+ parser.add_argument(
31
+ 'output_path',
32
+ type=str,
33
+ help='path for training outputs.')
34
+ parser.add_argument(
35
+ '--target_dataset',
36
+ type=str,
37
+ default='',
38
+ help='Target dataset to pick a processor from TTS.tts.dataset.preprocess. Necessary to create a speakers.json file.'
39
+ )
40
+ parser.add_argument(
41
+ '--use_cuda', type=bool, help='flag to set cuda.', default=False
42
+ )
43
+ parser.add_argument(
44
+ '--separator', type=str, help='Separator used in file if CSV is passed for data_path', default='|'
45
+ )
46
+ args = parser.parse_args()
47
+
48
+
49
+ c = load_config(args.config_path)
50
+ ap = AudioProcessor(**c['audio'])
51
+
52
+ data_path = args.data_path
53
+ split_ext = os.path.splitext(data_path)
54
+ sep = args.separator
55
+
56
+ if args.target_dataset != '':
57
+ # if target dataset is defined
58
+ dataset_config = [
59
+ {
60
+ "name": args.target_dataset,
61
+ "path": args.data_path,
62
+ "meta_file_train": None,
63
+ "meta_file_val": None
64
+ },
65
+ ]
66
+ wav_files, _ = load_meta_data(dataset_config, eval_split=False)
67
+ output_files = [wav_file[1].replace(data_path, args.output_path).replace(
68
+ '.wav', '.npy') for wav_file in wav_files]
69
+ else:
70
+ # if target dataset is not defined
71
+ if len(split_ext) > 0 and split_ext[1].lower() == '.csv':
72
+ # Parse CSV
73
+ print(f'CSV file: {data_path}')
74
+ with open(data_path) as f:
75
+ wav_path = os.path.join(os.path.dirname(data_path), 'wavs')
76
+ wav_files = []
77
+ print(f'Separator is: {sep}')
78
+ for line in f:
79
+ components = line.split(sep)
80
+ if len(components) != 2:
81
+ print("Invalid line")
82
+ continue
83
+ wav_file = os.path.join(wav_path, components[0] + '.wav')
84
+ #print(f'wav_file: {wav_file}')
85
+ if os.path.exists(wav_file):
86
+ wav_files.append(wav_file)
87
+ print(f'Count of wavs imported: {len(wav_files)}')
88
+ else:
89
+ # Parse all wav files in data_path
90
+ wav_files = glob.glob(data_path + '/**/*.wav', recursive=True)
91
+
92
+ output_files = [wav_file.replace(data_path, args.output_path).replace(
93
+ '.wav', '.npy') for wav_file in wav_files]
94
+
95
+ for output_file in output_files:
96
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
97
+
98
+ # define Encoder model
99
+ model = SpeakerEncoder(**c.model)
100
+ model.load_state_dict(torch.load(args.model_path)['model'])
101
+ model.eval()
102
+ if args.use_cuda:
103
+ model.cuda()
104
+
105
+ # compute speaker embeddings
106
+ speaker_mapping = {}
107
+ for idx, wav_file in enumerate(tqdm(wav_files)):
108
+ if isinstance(wav_file, list):
109
+ speaker_name = wav_file[2]
110
+ wav_file = wav_file[1]
111
+
112
+ mel_spec = ap.melspectrogram(ap.load_wav(wav_file, sr=ap.sample_rate)).T
113
+ mel_spec = torch.FloatTensor(mel_spec[None, :, :])
114
+ if args.use_cuda:
115
+ mel_spec = mel_spec.cuda()
116
+ embedd = model.compute_embedding(mel_spec)
117
+ embedd = embedd.detach().cpu().numpy()
118
+ np.save(output_files[idx], embedd)
119
+
120
+ if args.target_dataset != '':
121
+ # create speaker_mapping if target dataset is defined
122
+ wav_file_name = os.path.basename(wav_file)
123
+ speaker_mapping[wav_file_name] = {}
124
+ speaker_mapping[wav_file_name]['name'] = speaker_name
125
+ speaker_mapping[wav_file_name]['embedding'] = embedd.flatten().tolist()
126
+
127
+ if args.target_dataset != '':
128
+ # save speaker_mapping if target dataset is defined
129
+ mapping_file_path = os.path.join(args.output_path, 'speakers.json')
130
+ save_speaker_mapping(args.output_path, speaker_mapping)
TTS/bin/compute_statistics.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ import glob
6
+ import argparse
7
+
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+
11
+ from TTS.tts.datasets.preprocess import load_meta_data
12
+ from TTS.utils.io import load_config
13
+ from TTS.utils.audio import AudioProcessor
14
+
15
+
16
+ def main():
17
+ """Run preprocessing process."""
18
+ parser = argparse.ArgumentParser(
19
+ description="Compute mean and variance of spectrogtram features.")
20
+ parser.add_argument("--config_path", type=str, required=True,
21
+ help="TTS config file path to define audio processin parameters.")
22
+ parser.add_argument("--out_path", default=None, type=str,
23
+ help="directory to save the output file.")
24
+ args = parser.parse_args()
25
+
26
+ # load config
27
+ CONFIG = load_config(args.config_path)
28
+ CONFIG.audio['signal_norm'] = False # do not apply earlier normalization
29
+ CONFIG.audio['stats_path'] = None # discard pre-defined stats
30
+
31
+ # load audio processor
32
+ ap = AudioProcessor(**CONFIG.audio)
33
+
34
+ # load the meta data of target dataset
35
+ if 'data_path' in CONFIG.keys():
36
+ dataset_items = glob.glob(os.path.join(CONFIG.data_path, '**', '*.wav'), recursive=True)
37
+ else:
38
+ dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data
39
+ print(f" > There are {len(dataset_items)} files.")
40
+
41
+ mel_sum = 0
42
+ mel_square_sum = 0
43
+ linear_sum = 0
44
+ linear_square_sum = 0
45
+ N = 0
46
+ for item in tqdm(dataset_items):
47
+ # compute features
48
+ wav = ap.load_wav(item if isinstance(item, str) else item[1])
49
+ linear = ap.spectrogram(wav)
50
+ mel = ap.melspectrogram(wav)
51
+
52
+ # compute stats
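+ # accumulate sufficient statistics per frequency bin: mean = sum(x)/N, std = sqrt(sum(x^2)/N - mean^2)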
53
+ N += mel.shape[1]
54
+ mel_sum += mel.sum(1)
55
+ linear_sum += linear.sum(1)
56
+ mel_square_sum += (mel ** 2).sum(axis=1)
57
+ linear_square_sum += (linear ** 2).sum(axis=1)
58
+
59
+ mel_mean = mel_sum / N
60
+ mel_scale = np.sqrt(mel_square_sum / N - mel_mean ** 2)
61
+ linear_mean = linear_sum / N
62
+ linear_scale = np.sqrt(linear_square_sum / N - linear_mean ** 2)
63
+
64
+ output_file_path = args.out_path
65
+ stats = {}
66
+ stats['mel_mean'] = mel_mean
67
+ stats['mel_std'] = mel_scale
68
+ stats['linear_mean'] = linear_mean
69
+ stats['linear_std'] = linear_scale
70
+
71
+ print(f' > Avg mel spec mean: {mel_mean.mean()}')
72
+ print(f' > Avg mel spec scale: {mel_scale.mean()}')
73
+ print(f' > Avg linear spec mean: {linear_mean.mean()}')
74
+ print(f' > Avg linear spec scale: {linear_scale.mean()}')
75
+
76
+ # set default config values for mean-var scaling
77
+ CONFIG.audio['stats_path'] = output_file_path
78
+ CONFIG.audio['signal_norm'] = True
79
+ # remove redundant values
80
+ del CONFIG.audio['max_norm']
81
+ del CONFIG.audio['min_level_db']
82
+ del CONFIG.audio['symmetric_norm']
83
+ del CONFIG.audio['clip_norm']
84
+ stats['audio_config'] = CONFIG.audio
85
+ np.save(output_file_path, stats, allow_pickle=True)
86
+ print(f' > stats saved to {output_file_path}')
87
+
88
+
89
+ if __name__ == "__main__":
90
+ main()
TTS/bin/convert_melgan_tflite.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Convert Tensorflow MelGAN model to TF-Lite binary
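+ # Usage: python convert_melgan_tflite.py --tf_model <tf_checkpoint> --config_path <config.json> --output_path <model.tflite>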
2
+
3
+ import argparse
4
+
5
+ from TTS.utils.io import load_config
6
+ from TTS.vocoder.tf.utils.generic_utils import setup_generator
7
+ from TTS.vocoder.tf.utils.io import load_checkpoint
8
+ from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite
9
+
10
+
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument('--tf_model',
13
+ type=str,
14
+ help='Path to target torch model to be converted to TF.')
15
+ parser.add_argument('--config_path',
16
+ type=str,
17
+ help='Path to config file of torch model.')
18
+ parser.add_argument('--output_path',
19
+ type=str,
20
+ help='path to tflite output binary.')
21
+ args = parser.parse_args()
22
+
23
+ # Set constants
24
+ CONFIG = load_config(args.config_path)
25
+
26
+ # load the model
27
+ model = setup_generator(CONFIG)
28
+ model.build_inference()
29
+ model = load_checkpoint(model, args.tf_model)
30
+
31
+ # create tflite model
32
+ tflite_model = convert_melgan_to_tflite(model, output_path=args.output_path)
TTS/bin/convert_melgan_torch_to_tf.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from difflib import SequenceMatcher
3
+ import os
4
+
5
+ import numpy as np
6
+ import tensorflow as tf
7
+ import torch
8
+
9
+ from TTS.utils.io import load_config
10
+ from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
11
+ compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
12
+ from TTS.vocoder.tf.utils.generic_utils import \
13
+ setup_generator as setup_tf_generator
14
+ from TTS.vocoder.tf.utils.io import save_checkpoint
15
+ from TTS.vocoder.utils.generic_utils import setup_generator
16
+
17
+ # prevent GPU use
18
+ os.environ['CUDA_VISIBLE_DEVICES'] = ''
19
+
20
+ # define args
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument('--torch_model_path',
23
+ type=str,
24
+ help='Path to target torch model to be converted to TF.')
25
+ parser.add_argument('--config_path',
26
+ type=str,
27
+ help='Path to config file of torch model.')
28
+ parser.add_argument(
29
+ '--output_path',
30
+ type=str,
31
+ help='path to output file including file name to save TF model.')
32
+ args = parser.parse_args()
33
+
34
+ # load model config
35
+ config_path = args.config_path
36
+ c = load_config(config_path)
37
+ num_speakers = 0
38
+
39
+ # init torch model
40
+ model = setup_generator(c)
41
+ checkpoint = torch.load(args.torch_model_path,
42
+ map_location=torch.device('cpu'))
43
+ state_dict = checkpoint['model']
44
+ model.load_state_dict(state_dict)
45
+ model.remove_weight_norm()
46
+ state_dict = model.state_dict()
47
+
48
+ # init tf model
49
+ model_tf = setup_tf_generator(c)
50
+
51
+ common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
52
+ # get tf_model graph by passing an input
53
+ # B x D x T
54
+ dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32)
55
+ mel_pred = model_tf(dummy_input, training=False)
56
+
57
+ # get tf variables
58
+ tf_vars = model_tf.weights
59
+
60
+ # match variable names with fuzzy logic
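+ # greedily pair each TF variable with the closest-named remaining torch tensor (SequenceMatcher ratio)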
61
+ torch_var_names = list(state_dict.keys())
62
+ tf_var_names = [we.name for we in model_tf.weights]
63
+ var_map = []
64
+ for tf_name in tf_var_names:
65
+ # skip re-mapped layer names
66
+ if tf_name in [name[0] for name in var_map]:
67
+ continue
68
+ tf_name_edited = convert_tf_name(tf_name)
69
+ ratios = [
70
+ SequenceMatcher(None, torch_name, tf_name_edited).ratio()
71
+ for torch_name in torch_var_names
72
+ ]
73
+ max_idx = np.argmax(ratios)
74
+ matching_name = torch_var_names[max_idx]
75
+ del torch_var_names[max_idx]
76
+ var_map.append((tf_name, matching_name))
77
+
78
+ # pass weights
79
+ tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
80
+
81
+ # Compare TF and TORCH models
82
+ # check embedding outputs
83
+ model.eval()
84
+ dummy_input_torch = torch.ones((1, 80, 10))
85
+ dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
86
+ dummy_input_tf = tf.transpose(dummy_input_tf, perm=[0, 2, 1])
87
+ dummy_input_tf = tf.expand_dims(dummy_input_tf, 2)
88
+
89
+ out_torch = model.layers[0](dummy_input_torch)
90
+ out_tf = model_tf.model_layers[0](dummy_input_tf)
91
+ out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
92
+
93
+ assert compare_torch_tf(out_torch, out_tf_) < 1e-5
94
+
95
+ for i in range(1, len(model.layers)):
96
+ print(f"{i} -> {model.layers[i]} vs {model_tf.model_layers[i]}")
97
+ out_torch = model.layers[i](out_torch)
98
+ out_tf = model_tf.model_layers[i](out_tf)
99
+ out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
100
+ diff = compare_torch_tf(out_torch, out_tf_)
101
+ assert diff < 1e-5, diff
102
+
103
+ torch.manual_seed(0)
104
+ dummy_input_torch = torch.rand((1, 80, 100))
105
+ dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
106
+ model.inference_padding = 0
107
+ model_tf.inference_padding = 0
108
+ output_torch = model.inference(dummy_input_torch)
109
+ output_tf = model_tf(dummy_input_tf, training=False)
110
+ assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(
111
+ output_torch, output_tf)
112
+
113
+ # save tf model
114
+ save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'],
115
+ args.output_path)
116
+ print(' > Model conversion is successfully completed :).')
TTS/bin/convert_tacotron2_tflite.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Convert Tensorflow Tacotron2 model to TF-Lite binary
2
+
3
+ import argparse
4
+
5
+ from TTS.utils.io import load_config
6
+ from TTS.tts.utils.text.symbols import symbols, phonemes
7
+ from TTS.tts.tf.utils.generic_utils import setup_model
8
+ from TTS.tts.tf.utils.io import load_checkpoint
9
+ from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite
10
+
11
+
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--tf_model',
14
+ type=str,
15
+ help='Path to target torch model to be converted to TF.')
16
+ parser.add_argument('--config_path',
17
+ type=str,
18
+ help='Path to config file of torch model.')
19
+ parser.add_argument('--output_path',
20
+ type=str,
21
+ help='path to tflite output binary.')
22
+ args = parser.parse_args()
23
+
24
+ # Set constants
25
+ CONFIG = load_config(args.config_path)
26
+
27
+ # load the model
28
+ c = CONFIG
29
+ num_speakers = 0
30
+ num_chars = len(phonemes) if c.use_phonemes else len(symbols)
31
+ model = setup_model(num_chars, num_speakers, c, enable_tflite=True)
32
+ model.build_inference()
33
+ model = load_checkpoint(model, args.tf_model)
34
+ model.decoder.set_max_decoder_steps(1000)
35
+
36
+ # create tflite model
37
+ tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path)
TTS/bin/convert_tacotron2_torch_to_tf.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ # %%
3
+ import argparse
4
+ from difflib import SequenceMatcher
5
+ import os
6
+ import sys
7
+ # %%
8
+ # print variable match
9
+ from pprint import pprint
10
+
11
+ import numpy as np
12
+ import tensorflow as tf
13
+ import torch
14
+ from TTS.tts.tf.models.tacotron2 import Tacotron2
15
+ from TTS.tts.tf.utils.convert_torch_to_tf_utils import (
16
+ compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
17
+ from TTS.tts.tf.utils.generic_utils import save_checkpoint
18
+ from TTS.tts.utils.generic_utils import setup_model
19
+ from TTS.tts.utils.text.symbols import phonemes, symbols
20
+ from TTS.utils.io import load_config
21
+
22
+ sys.path.append('/home/erogol/Projects')
23
+ os.environ['CUDA_VISIBLE_DEVICES'] = ''
24
+
25
+
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument('--torch_model_path',
28
+ type=str,
29
+ help='Path to target torch model to be converted to TF.')
30
+ parser.add_argument('--config_path',
31
+ type=str,
32
+ help='Path to config file of torch model.')
33
+ parser.add_argument('--output_path',
34
+ type=str,
35
+ help='path to output file including file name to save TF model.')
36
+ args = parser.parse_args()
37
+
38
+ # load model config
39
+ config_path = args.config_path
40
+ c = load_config(config_path)
41
+ num_speakers = 0
42
+
43
+ # init torch model
44
+ num_chars = len(phonemes) if c.use_phonemes else len(symbols)
45
+ model = setup_model(num_chars, num_speakers, c)
46
+ checkpoint = torch.load(args.torch_model_path,
47
+ map_location=torch.device('cpu'))
48
+ state_dict = checkpoint['model']
49
+ model.load_state_dict(state_dict)
50
+
51
+ # init tf model
52
+ model_tf = Tacotron2(num_chars=num_chars,
53
+ num_speakers=num_speakers,
54
+ r=model.decoder.r,
55
+ postnet_output_dim=c.audio['num_mels'],
56
+ decoder_output_dim=c.audio['num_mels'],
57
+ attn_type=c.attention_type,
58
+ attn_win=c.windowing,
59
+ attn_norm=c.attention_norm,
60
+ prenet_type=c.prenet_type,
61
+ prenet_dropout=c.prenet_dropout,
62
+ forward_attn=c.use_forward_attn,
63
+ trans_agent=c.transition_agent,
64
+ forward_attn_mask=c.forward_attn_mask,
65
+ location_attn=c.location_attn,
66
+ attn_K=c.attention_heads,
67
+ separate_stopnet=c.separate_stopnet,
68
+ bidirectional_decoder=c.bidirectional_decoder)
69
+
70
+ # set initial layer mapping - these are not captured by the below heuristic approach
71
+ # TODO: set layer names so that we can remove these manual matching
72
+ common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
73
+ var_map = [
74
+ ('embedding/embeddings:0', 'embedding.weight'),
75
+ ('encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
76
+ 'encoder.lstm.weight_ih_l0'),
77
+ ('encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
78
+ 'encoder.lstm.weight_hh_l0'),
79
+ ('encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
80
+ 'encoder.lstm.weight_ih_l0_reverse'),
81
+ ('encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
82
+ 'encoder.lstm.weight_hh_l0_reverse'),
83
+ ('encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
84
+ ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
85
+ ('encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
86
+ ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
87
+ ('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'),
88
+ ('decoder/linear_projection/kernel:0',
89
+ 'decoder.linear_projection.linear_layer.weight'),
90
+ ('decoder/stopnet/kernel:0', 'decoder.stopnet.1.linear_layer.weight')
91
+ ]
92
+
93
+ # %%
94
+ # get tf_model graph
95
+ model_tf.build_inference()
96
+
97
+ # get tf variables
98
+ tf_vars = model_tf.weights
99
+
100
+ # match variable names with fuzzy logic
101
+ torch_var_names = list(state_dict.keys())
102
+ tf_var_names = [we.name for we in model_tf.weights]
103
+ for tf_name in tf_var_names:
104
+ # skip re-mapped layer names
105
+ if tf_name in [name[0] for name in var_map]:
106
+ continue
107
+ tf_name_edited = convert_tf_name(tf_name)
108
+ ratios = [
109
+ SequenceMatcher(None, torch_name, tf_name_edited).ratio()
110
+ for torch_name in torch_var_names
111
+ ]
112
+ max_idx = np.argmax(ratios)
113
+ matching_name = torch_var_names[max_idx]
114
+ del torch_var_names[max_idx]
115
+ var_map.append((tf_name, matching_name))
116
+
117
+ pprint(var_map)
118
+ pprint(torch_var_names)
119
+
120
+ # pass weights
121
+ tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
122
+
123
+ # Compare TF and TORCH models
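+ # sanity-check each submodule output (embedding, encoder, attention, decoder) against the torch reference within small tolerances (1e-5 / 1e-4)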
124
+ # %%
125
+ # check embedding outputs
126
+ model.eval()
127
+ input_ids = torch.randint(0, 24, (1, 128)).long()
128
+
129
+ o_t = model.embedding(input_ids)
130
+ o_tf = model_tf.embedding(input_ids.detach().numpy())
131
+ assert abs(o_t.detach().numpy() -
132
+ o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() -
133
+ o_tf.numpy()).sum()
134
+
135
+ # compare encoder outputs
136
+ oo_en = model.encoder.inference(o_t.transpose(1, 2))
137
+ ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False)
138
+ assert compare_torch_tf(oo_en, ooo_en) < 1e-5
139
+
140
+ #pylint: disable=redefined-builtin
141
+ # compare decoder.attention_rnn
142
+ inp = torch.rand([1, 768])
143
+ inp_tf = inp.numpy()
144
+ model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access
145
+ output, cell_state = model.decoder.attention_rnn(inp)
146
+ states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
147
+ output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf,
148
+ states[2],
149
+ training=False)
150
+ assert compare_torch_tf(output, output_tf).mean() < 1e-5
151
+
152
+ query = output
153
+ inputs = torch.rand([1, 128, 512])
154
+ query_tf = query.detach().numpy()
155
+ inputs_tf = inputs.numpy()
156
+
157
+ # compare decoder.attention
158
+ model.decoder.attention.init_states(inputs)
159
+ processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
160
+ loc_attn, proc_query = model.decoder.attention.get_location_attention(
161
+ query, processes_inputs)
162
+ context = model.decoder.attention(query, inputs, processes_inputs, None)
163
+
164
+ attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1]
165
+ model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf))
166
+ loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf, attention_states)
167
+ context_tf, attention, attention_states = model_tf.decoder.attention(query_tf, attention_states, training=False)
168
+
169
+ assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5
170
+ assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5
171
+ assert compare_torch_tf(context, context_tf) < 1e-5
172
+
173
+ # compare decoder.decoder_rnn
174
+ input = torch.rand([1, 1536])
175
+ input_tf = input.numpy()
176
+ model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access
177
+ output, cell_state = model.decoder.decoder_rnn(
178
+ input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
179
+ states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
180
+ output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf,
181
+ states[3],
182
+ training=False)
183
+ assert abs(input - input_tf).mean() < 1e-5
184
+ assert compare_torch_tf(output, output_tf).mean() < 1e-5
185
+
186
+ # compare decoder.linear_projection
187
+ input = torch.rand([1, 1536])
188
+ input_tf = input.numpy()
189
+ output = model.decoder.linear_projection(input)
190
+ output_tf = model_tf.decoder.linear_projection(input_tf, training=False)
191
+ assert compare_torch_tf(output, output_tf) < 1e-5
192
+
193
+ # compare decoder outputs
194
+ model.decoder.max_decoder_steps = 100
195
+ model_tf.decoder.set_max_decoder_steps(100)
196
+ output, align, stop = model.decoder.inference(oo_en)
197
+ states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
198
+ output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False)
199
+ assert compare_torch_tf(output.transpose(1, 2), output_tf) < 1e-4
200
+
201
+ # compare the whole model output
202
+ outputs_torch = model.inference(input_ids)
203
+ outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy()))
204
+ print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean())
205
+ assert compare_torch_tf(outputs_torch[2][:, 50, :],
206
+ outputs_tf[2][:, 50, :]) < 1e-5
207
+ assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4
208
+
209
+ # %%
210
+ # save tf model
211
+ save_checkpoint(model_tf, None, checkpoint['step'], checkpoint['epoch'],
212
+ checkpoint['r'], args.output_path)
213
+ print(' > Model conversion is successfully completed :).')
TTS/bin/distribute.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ import sys
6
+ import pathlib
7
+ import time
8
+ import subprocess
9
+ import argparse
10
+ import torch
11
+
12
+
13
+ def main():
14
+ """
15
+ Run the given training script as a new process per GPU and pass through the command arguments
16
+ """
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument(
19
+ '--script',
20
+ type=str,
21
+ help='Target training script to distribute.')
22
+ parser.add_argument(
23
+ '--continue_path',
24
+ type=str,
25
+ help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
26
+ default='',
27
+ required='--config_path' not in sys.argv)
28
+ parser.add_argument(
29
+ '--restore_path',
30
+ type=str,
31
+ help='Model file to be restored. Use to finetune a model.',
32
+ default='')
33
+ parser.add_argument(
34
+ '--config_path',
35
+ type=str,
36
+ help='Path to config file for training.',
37
+ required='--continue_path' not in sys.argv
38
+ )
39
+ args = parser.parse_args()
40
+
41
+ num_gpus = torch.cuda.device_count()
42
+ group_id = time.strftime("%Y_%m_%d-%H%M%S")
43
+
44
+ # set arguments for train.py
45
+ folder_path = pathlib.Path(__file__).parent.absolute()
46
+ command = [os.path.join(folder_path, args.script)]
47
+ command.append('--continue_path={}'.format(args.continue_path))
48
+ command.append('--restore_path={}'.format(args.restore_path))
49
+ command.append('--config_path={}'.format(args.config_path))
50
+ command.append('--group_id=group_{}'.format(group_id))
51
+ command.append('')
52
+
53
+ # run processes
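+ # spawn one training process per visible GPU; only rank 0 keeps stdout, the rest write to /dev/null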
54
+ processes = []
55
+ for i in range(num_gpus):
56
+ my_env = os.environ.copy()
57
+ my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
58
+ command[-1] = '--rank={}'.format(i)
59
+ stdout = None if i == 0 else open(os.devnull, 'w')
60
+ p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env)
61
+ processes.append(p)
62
+ print(command)
63
+
64
+ for p in processes:
65
+ p.wait()
66
+
67
+
68
+ if __name__ == '__main__':
69
+ main()
TTS/bin/synthesize.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import argparse
5
+ import os
6
+ import sys
7
+ import string
8
+ from argparse import RawTextHelpFormatter
9
+ # pylint: disable=redefined-outer-name, unused-argument
10
+ from pathlib import Path
11
+
12
+ from TTS.utils.manage import ModelManager
13
+ from TTS.utils.synthesizer import Synthesizer
14
+
15
+
16
+ def str2bool(v):
17
+ if isinstance(v, bool):
18
+ return v
19
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
20
+ return True
21
+ if v.lower() in ('no', 'false', 'f', 'n', '0'):
22
+ return False
23
+ raise argparse.ArgumentTypeError('Boolean value expected.')
24
+
25
+
26
+ def main():
27
+ # pylint: disable=bad-continuation
28
+ parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n'''
29
+
30
+ '''You can either use your trained model or choose a model from the provided list.\n'''\
31
+
32
+ '''
33
+ Example runs:
34
+
35
+ # list provided models
36
+ ./TTS/bin/synthesize.py --list_models
37
+
38
+ # run a model from the list
39
+ ./TTS/bin/synthesize.py --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>" --vocoder_name "<language>/<dataset>/<model_name>" --output_path
40
+
41
+ # run your own TTS model (Using Griffin-Lim Vocoder)
42
+ ./TTS/bin/synthesize.py --text "Text for TTS" --model_path path/to/model.pth.tar --config_path path/to/config.json --out_path output/path/speech.wav
43
+
44
+ # run your own TTS and Vocoder models
45
+ ./TTS/bin/synthesize.py --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth.tar --out_path output/path/speech.wav
46
+ --vocoder_path path/to/vocoder.pth.tar --vocoder_config_path path/to/vocoder_config.json
47
+
48
+ ''',
49
+ formatter_class=RawTextHelpFormatter)
50
+
51
+ parser.add_argument(
52
+ '--list_models',
53
+ type=str2bool,
54
+ nargs='?',
55
+ const=True,
56
+ default=False,
57
+ help='list available pre-trained tts and vocoder models.'
58
+ )
59
+ parser.add_argument(
60
+ '--text',
61
+ type=str,
62
+ default=None,
63
+ help='Text to generate speech.'
64
+ )
65
+
66
+ # Args for running pre-trained TTS models.
67
+ parser.add_argument(
68
+ '--model_name',
69
+ type=str,
70
+ default=None,
71
+ help=
72
+ 'Name of one of the pre-trained tts models in format <language>/<dataset>/<model_name>'
73
+ )
74
+ parser.add_argument(
75
+ '--vocoder_name',
76
+ type=str,
77
+ default=None,
78
+ help=
79
+ 'Name of one of the pre-trained vocoder models in format <language>/<dataset>/<model_name>'
80
+ )
81
+
82
+ # Args for running custom models
83
+ parser.add_argument(
84
+ '--config_path',
85
+ default=None,
86
+ type=str,
87
+ help='Path to model config file.'
88
+ )
89
+ parser.add_argument(
90
+ '--model_path',
91
+ type=str,
92
+ default=None,
93
+ help='Path to model file.',
94
+ )
95
+ parser.add_argument(
96
+ '--out_path',
97
+ type=str,
98
+ default=Path(__file__).resolve().parent,
99
+ help='Path to save the final wav file. The wav file will be named after the given text.',
100
+ )
101
+ parser.add_argument(
102
+ '--use_cuda',
103
+ type=bool,
104
+ help='Run model on CUDA.',
105
+ default=False
106
+ )
107
+ parser.add_argument(
108
+ '--vocoder_path',
109
+ type=str,
110
+ help=
111
+ 'Path to vocoder model file. If it is not defined, the model uses Griffin-Lim as the vocoder. Please make sure that you installed the vocoder library beforehand (e.g. WaveRNN).',
112
+ default=None,
113
+ )
114
+ parser.add_argument(
115
+ '--vocoder_config_path',
116
+ type=str,
117
+ help='Path to vocoder model config file.',
118
+ default=None)
119
+
120
+ # args for multi-speaker synthesis
121
+ parser.add_argument(
122
+ '--speakers_json',
123
+ type=str,
124
+ help="JSON file for multi-speaker model.",
125
+ default=None)
126
+ parser.add_argument(
127
+ '--speaker_idx',
128
+ type=str,
129
+ help="if the tts model is trained with x-vectors, then speaker_idx is a file present in speakers.json else speaker_idx is the speaker id corresponding to a speaker in the speaker embedding layer.",
130
+ default=None)
131
+ parser.add_argument(
132
+ '--gst_style',
133
+ help="Wav path file for GST stylereference.",
134
+ default=None)
135
+
136
+ # aux args
137
+ parser.add_argument(
138
+ '--save_spectogram',
139
+ type=bool,
140
+ help="If true save raw spectogram for further (vocoder) processing in out_path.",
141
+ default=False)
142
+
143
+ args = parser.parse_args()
144
+
145
+ # load model manager
146
+ path = Path(__file__).parent / "../.models.json"
147
+ manager = ModelManager(path)
148
+
149
+ model_path = None
150
+ config_path = None
151
+ vocoder_path = None
152
+ vocoder_config_path = None
153
+
154
+ # CASE1: list pre-trained TTS models
155
+ if args.list_models:
156
+ manager.list_models()
157
+ sys.exit()
158
+
159
+ # CASE2: load pre-trained models
160
+ if args.model_name is not None:
161
+ model_path, config_path = manager.download_model(args.model_name)
162
+
163
+ if args.vocoder_name is not None:
164
+ vocoder_path, vocoder_config_path = manager.download_model(args.vocoder_name)
165
+
166
+ # CASE3: load custom models
167
+ if args.model_path is not None:
168
+ model_path = args.model_path
169
+ config_path = args.config_path
170
+
171
+ if args.vocoder_path is not None:
172
+ vocoder_path = args.vocoder_path
173
+ vocoder_config_path = args.vocoder_config_path
174
+
175
+ # RUN THE SYNTHESIS
176
+ # load models
177
+ synthesizer = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path, args.use_cuda)
178
+
179
+ use_griffin_lim = vocoder_path is None
180
+ print(" > Text: {}".format(args.text))
181
+
182
+ # # handle multi-speaker setting
183
+ # if not model_config.use_external_speaker_embedding_file and args.speaker_idx is not None:
184
+ # if args.speaker_idx.isdigit():
185
+ # args.speaker_idx = int(args.speaker_idx)
186
+ # else:
187
+ # args.speaker_idx = None
188
+ # else:
189
+ # args.speaker_idx = None
190
+
191
+ # if args.gst_style is None:
192
+ # if 'gst' in model_config.keys() and model_config.gst['gst_style_input'] is not None:
193
+ # gst_style = model_config.gst['gst_style_input']
194
+ # else:
195
+ # gst_style = None
196
+ # else:
197
+ # # check if gst_style string is a dict, if is dict convert else use string
198
+ # try:
199
+ # gst_style = json.loads(args.gst_style)
200
+ # if max(map(int, gst_style.keys())) >= model_config.gst['gst_style_tokens']:
201
+ # raise RuntimeError("The highest value of the gst_style dictionary key must be less than the number of GST Tokens, \n Highest dictionary key value: {} \n Number of GST tokens: {}".format(max(map(int, gst_style.keys())), model_config.gst['gst_style_tokens']))
202
+ # except ValueError:
203
+ # gst_style = args.gst_style
204
+
205
+ # kick it
206
+ wav = synthesizer.tts(args.text)
207
+
208
+ # save the results
209
+ file_name = args.text.replace(" ", "_")[0:20]
210
+ file_name = file_name.translate(
211
+ str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'
212
+ out_path = os.path.join(args.out_path, file_name)
213
+ print(" > Saving output to {}".format(out_path))
214
+ synthesizer.save_wav(wav, out_path)
215
+
216
+
217
+ if __name__ == "__main__":
218
+ main()
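The output filename above is derived from the input text itself: the first 20 characters, with spaces turned into underscores and every punctuation character except the underscore stripped via str.translate. A self-contained sketch of the same transformation:

import string

def text_to_filename(text):
    # first 20 characters, spaces -> underscores
    name = text.replace(" ", "_")[:20]
    # drop every punctuation character except '_' and add the extension
    table = str.maketrans('', '', string.punctuation.replace('_', ''))
    return name.translate(table) + '.wav'

print(text_to_filename("Hello, world!"))  # -> Hello_world.wav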
TTS/bin/train_encoder.py ADDED
@@ -0,0 +1,274 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import argparse
5
+ import os
6
+ import sys
7
+ import time
8
+ import traceback
9
+
10
+ import torch
11
+ from torch.utils.data import DataLoader
12
+ from TTS.speaker_encoder.dataset import MyDataset
13
+ from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss
14
+ from TTS.speaker_encoder.model import SpeakerEncoder
15
+ from TTS.speaker_encoder.utils.generic_utils import \
16
+ check_config_speaker_encoder, save_best_model
17
+ from TTS.speaker_encoder.utils.visual import plot_embeddings
18
+ from TTS.tts.datasets.preprocess import load_meta_data
19
+ from TTS.utils.audio import AudioProcessor
20
+ from TTS.utils.generic_utils import (count_parameters,
21
+ create_experiment_folder, get_git_branch,
22
+ remove_experiment_folder, set_init_dict)
23
+ from TTS.utils.io import copy_model_files, load_config
24
+ from TTS.utils.radam import RAdam
25
+ from TTS.utils.tensorboard_logger import TensorboardLogger
26
+ from TTS.utils.training import NoamLR, check_update
27
+
28
+ torch.backends.cudnn.enabled = True
29
+ torch.backends.cudnn.benchmark = True
30
+ torch.manual_seed(54321)
31
+ use_cuda = torch.cuda.is_available()
32
+ num_gpus = torch.cuda.device_count()
33
+ print(" > Using CUDA: ", use_cuda)
34
+ print(" > Number of GPUs: ", num_gpus)
35
+
36
+
37
+ def setup_loader(ap: AudioProcessor, is_val: bool=False, verbose: bool=False):
38
+ if is_val:
39
+ loader = None
40
+ else:
41
+ dataset = MyDataset(ap,
42
+ meta_data_eval if is_val else meta_data_train,
43
+ voice_len=1.6,
44
+ num_utter_per_speaker=c.num_utters_per_speaker,
45
+ num_speakers_in_batch=c.num_speakers_in_batch,
46
+ skip_speakers=False,
47
+ storage_size=c.storage["storage_size"],
48
+ sample_from_storage_p=c.storage["sample_from_storage_p"],
49
+ additive_noise=c.storage["additive_noise"],
50
+ verbose=verbose)
51
+ # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
52
+ loader = DataLoader(dataset,
53
+ batch_size=c.num_speakers_in_batch,
54
+ shuffle=False,
55
+ num_workers=c.num_loader_workers,
56
+ collate_fn=dataset.collate_fn)
57
+ return loader
58
+
59
+
60
+ def train(model, criterion, optimizer, scheduler, ap, global_step):
61
+ data_loader = setup_loader(ap, is_val=False, verbose=True)
62
+ model.train()
63
+ epoch_time = 0
64
+ best_loss = float('inf')
65
+ avg_loss = 0
66
+ avg_loader_time = 0
67
+ end_time = time.time()
68
+ for _, data in enumerate(data_loader):
69
+ start_time = time.time()
70
+
71
+ # setup input data
72
+ inputs = data[0]
73
+ loader_time = time.time() - end_time
74
+ global_step += 1
75
+
76
+ # setup lr
77
+ if c.lr_decay:
78
+ scheduler.step()
79
+ optimizer.zero_grad()
80
+
81
+ # dispatch data to GPU
82
+ if use_cuda:
83
+ inputs = inputs.cuda(non_blocking=True)
84
+ # labels = labels.cuda(non_blocking=True)
85
+
86
+ # forward pass model
87
+ outputs = model(inputs)
88
+
89
+ # loss computation
90
+ loss = criterion(
91
+ outputs.view(c.num_speakers_in_batch,
92
+ outputs.shape[0] // c.num_speakers_in_batch, -1))
93
+ loss.backward()
94
+ grad_norm, _ = check_update(model, c.grad_clip)
95
+ optimizer.step()
96
+
97
+ step_time = time.time() - start_time
98
+ epoch_time += step_time
99
+
100
+ # Averaged Loss and Averaged Loader Time
101
+ avg_loss = 0.01 * loss.item() \
102
+ + 0.99 * avg_loss if avg_loss != 0 else loss.item()
103
+ avg_loader_time = 1/c.num_loader_workers * loader_time + \
104
+ (c.num_loader_workers-1) / c.num_loader_workers * avg_loader_time if avg_loader_time != 0 else loader_time
105
+ current_lr = optimizer.param_groups[0]['lr']
106
+
107
+ if global_step % c.steps_plot_stats == 0:
108
+ # Plot Training Epoch Stats
109
+ train_stats = {
110
+ "loss": avg_loss,
111
+ "lr": current_lr,
112
+ "grad_norm": grad_norm,
113
+ "step_time": step_time,
114
+ "avg_loader_time": avg_loader_time
115
+ }
116
+ tb_logger.tb_train_epoch_stats(global_step, train_stats)
117
+ figures = {
118
+ # FIXME: not constant
119
+ "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(),
120
+ 10),
121
+ }
122
+ tb_logger.tb_train_figures(global_step, figures)
123
+
124
+ if global_step % c.print_step == 0:
125
+ print(
126
+ " | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} "
127
+ "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
128
+ global_step, loss.item(), avg_loss, grad_norm, step_time,
129
+ loader_time, avg_loader_time, current_lr),
130
+ flush=True)
131
+
132
+ # save best model
133
+ best_loss = save_best_model(model, optimizer, avg_loss, best_loss,
134
+ OUT_PATH, global_step)
135
+
136
+ end_time = time.time()
137
+ return avg_loss, global_step
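# A minimal sketch (illustrative, not part of the script) of the running
# averages computed above: `avg_loss` is an exponential moving average with
# weight 0.01 on the newest loss, and `avg_loader_time` uses weight
# 1/num_loader_workers; the first observation initializes the average.

def ema_update(avg, new_value, alpha=0.01):
    # blend the new value into the running average
    return new_value if avg == 0 else alpha * new_value + (1 - alpha) * avg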
138
+
139
+
140
+ def main(args): # pylint: disable=redefined-outer-name
141
+ # pylint: disable=global-variable-undefined
142
+ global meta_data_train
143
+ global meta_data_eval
144
+
145
+ ap = AudioProcessor(**c.audio)
146
+ model = SpeakerEncoder(input_dim=c.model['input_dim'],
147
+ proj_dim=c.model['proj_dim'],
148
+ lstm_dim=c.model['lstm_dim'],
149
+ num_lstm_layers=c.model['num_lstm_layers'])
150
+ optimizer = RAdam(model.parameters(), lr=c.lr)
151
+
152
+ if c.loss == "ge2e":
153
+ criterion = GE2ELoss(loss_method='softmax')
154
+ elif c.loss == "angleproto":
155
+ criterion = AngleProtoLoss()
156
+ else:
157
+ raise Exception("%s is not a supported loss" % c.loss)
158
+
159
+ if args.restore_path:
160
+ checkpoint = torch.load(args.restore_path)
161
+ try:
162
+ # TODO: fix optimizer init, model.cuda() needs to be called before
163
+ # optimizer restore
164
+ # optimizer.load_state_dict(checkpoint['optimizer'])
165
+ if c.reinit_layers:
166
+ raise RuntimeError
167
+ model.load_state_dict(checkpoint['model'])
168
+ except KeyError:
169
+ print(" > Partial model initialization.")
170
+ model_dict = model.state_dict()
171
+ model_dict = set_init_dict(model_dict, checkpoint, c)
172
+ model.load_state_dict(model_dict)
173
+ del model_dict
174
+ for group in optimizer.param_groups:
175
+ group['lr'] = c.lr
176
+ print(" > Model restored from step %d" % checkpoint['step'],
177
+ flush=True)
178
+ args.restore_step = checkpoint['step']
179
+ else:
180
+ args.restore_step = 0
181
+
182
+ if use_cuda:
183
+ model = model.cuda()
184
+ criterion.cuda()
185
+
186
+ if c.lr_decay:
187
+ scheduler = NoamLR(optimizer,
188
+ warmup_steps=c.warmup_steps,
189
+ last_epoch=args.restore_step - 1)
190
+ else:
191
+ scheduler = None
192
+
193
+ num_params = count_parameters(model)
194
+ print("\n > Model has {} parameters".format(num_params), flush=True)
195
+
196
+ # pylint: disable=redefined-outer-name
197
+ meta_data_train, meta_data_eval = load_meta_data(c.datasets)
198
+
199
+ global_step = args.restore_step
200
+ _, global_step = train(model, criterion, optimizer, scheduler, ap,
201
+ global_step)
202
+
203
+
204
+ if __name__ == '__main__':
205
+ parser = argparse.ArgumentParser()
206
+ parser.add_argument(
207
+ '--restore_path',
208
+ type=str,
209
+ help='Path to model outputs (checkpoint, tensorboard etc.).',
210
+ default='')
211
+ parser.add_argument(
212
+ '--config_path',
213
+ type=str,
214
+ required=True,
215
+ help='Path to config file for training.',
216
+ )
217
+ parser.add_argument('--debug',
218
+ type=bool,
219
+ default=True,
220
+ help='Do not verify commit integrity to run training.')
221
+ parser.add_argument(
222
+ '--data_path',
223
+ type=str,
224
+ default='',
225
+ help='Defines the data path. It overwrites config.json.')
226
+ parser.add_argument('--output_path',
227
+ type=str,
228
+ help='path for training outputs.',
229
+ default='')
230
+ parser.add_argument('--output_folder',
231
+ type=str,
232
+ default='',
233
+ help='folder name for training outputs.')
234
+ args = parser.parse_args()
235
+
236
+ # setup output paths and read configs
237
+ c = load_config(args.config_path)
238
+ check_config_speaker_encoder(c)
239
+ _ = os.path.dirname(os.path.realpath(__file__))
240
+ if args.data_path != '':
241
+ c.data_path = args.data_path
242
+
243
+ if args.output_path == '':
244
+ OUT_PATH = os.path.join(_, c.output_path)
245
+ else:
246
+ OUT_PATH = args.output_path
247
+
248
+ if args.output_folder == '':
249
+ OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug)
250
+ else:
251
+ OUT_PATH = os.path.join(OUT_PATH, args.output_folder)
252
+
253
+ new_fields = {}
254
+ if args.restore_path:
255
+ new_fields["restore_path"] = args.restore_path
256
+ new_fields["github_branch"] = get_git_branch()
257
+ copy_model_files(c, args.config_path, OUT_PATH,
258
+ new_fields)
259
+
260
+ LOG_DIR = OUT_PATH
261
+ tb_logger = TensorboardLogger(LOG_DIR, model_name='Speaker_Encoder')
262
+
263
+ try:
264
+ main(args)
265
+ except KeyboardInterrupt:
266
+ remove_experiment_folder(OUT_PATH)
267
+ try:
268
+ sys.exit(0)
269
+ except SystemExit:
270
+ os._exit(0) # pylint: disable=protected-access
271
+ except Exception: # pylint: disable=broad-except
272
+ remove_experiment_folder(OUT_PATH)
273
+ traceback.print_exc()
274
+ sys.exit(1)
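Note how the training loop above reshapes the encoder outputs before the loss: the loader yields num_speakers_in_batch * num_utters_per_speaker utterances in one flat batch, and GE2E / angular prototypical losses expect them grouped per speaker. A sketch of that view, with hypothetical sizes standing in for the config values:

import torch

num_speakers = 4          # stands in for c.num_speakers_in_batch
utters_per_speaker = 10   # stands in for c.num_utters_per_speaker
proj_dim = 256

# flat batch of embeddings as produced by the speaker encoder
outputs = torch.randn(num_speakers * utters_per_speaker, proj_dim)

# group utterances by speaker, as done before calling the criterion
grouped = outputs.view(num_speakers, outputs.shape[0] // num_speakers, -1)
assert grouped.shape == (num_speakers, utters_per_speaker, proj_dim)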
TTS/bin/train_glow_tts.py ADDED
@@ -0,0 +1,657 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import argparse
5
+ import glob
6
+ import os
7
+ import sys
8
+ import time
9
+ import traceback
10
+ from random import randrange
11
+
12
+ import torch
13
+ # DISTRIBUTED
14
+ from torch.nn.parallel import DistributedDataParallel as DDP_th
15
+ from torch.utils.data import DataLoader
16
+ from torch.utils.data.distributed import DistributedSampler
17
+ from TTS.tts.datasets.preprocess import load_meta_data
18
+ from TTS.tts.datasets.TTSDataset import MyDataset
19
+ from TTS.tts.layers.losses import GlowTTSLoss
20
+ from TTS.tts.utils.generic_utils import check_config_tts, setup_model
21
+ from TTS.tts.utils.io import save_best_model, save_checkpoint
22
+ from TTS.tts.utils.measures import alignment_diagonal_score
23
+ from TTS.tts.utils.speakers import parse_speakers
24
+ from TTS.tts.utils.synthesis import synthesis
25
+ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
26
+ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
27
+ from TTS.utils.audio import AudioProcessor
28
+ from TTS.utils.console_logger import ConsoleLogger
29
+ from TTS.utils.distribute import init_distributed, reduce_tensor
30
+ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
31
+ create_experiment_folder, get_git_branch,
32
+ remove_experiment_folder, set_init_dict)
33
+ from TTS.utils.io import copy_model_files, load_config
34
+ from TTS.utils.radam import RAdam
35
+ from TTS.utils.tensorboard_logger import TensorboardLogger
36
+ from TTS.utils.training import NoamLR, setup_torch_training_env
37
+
38
+ use_cuda, num_gpus = setup_torch_training_env(True, False)
39
+
40
+ def setup_loader(ap, r, is_val=False, verbose=False):
41
+ if is_val and not c.run_eval:
42
+ loader = None
43
+ else:
44
+ dataset = MyDataset(
45
+ r,
46
+ c.text_cleaner,
47
+ compute_linear_spec=False,
48
+ meta_data=meta_data_eval if is_val else meta_data_train,
49
+ ap=ap,
50
+ tp=c.characters if 'characters' in c.keys() else None,
51
+ add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
52
+ batch_group_size=0 if is_val else c.batch_group_size *
53
+ c.batch_size,
54
+ min_seq_len=c.min_seq_len,
55
+ max_seq_len=c.max_seq_len,
56
+ phoneme_cache_path=c.phoneme_cache_path,
57
+ use_phonemes=c.use_phonemes,
58
+ phoneme_language=c.phoneme_language,
59
+ enable_eos_bos=c.enable_eos_bos_chars,
60
+ use_noise_augment=c['use_noise_augment'] and not is_val,
61
+ verbose=verbose,
62
+ speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
63
+
64
+ if c.use_phonemes and c.compute_input_seq_cache:
65
+ # precompute phonemes to have a better estimate of sequence lengths.
66
+ dataset.compute_input_seq(c.num_loader_workers)
67
+ dataset.sort_items()
68
+
69
+ sampler = DistributedSampler(dataset) if num_gpus > 1 else None
70
+ loader = DataLoader(
71
+ dataset,
72
+ batch_size=c.eval_batch_size if is_val else c.batch_size,
73
+ shuffle=False,
74
+ collate_fn=dataset.collate_fn,
75
+ drop_last=False,
76
+ sampler=sampler,
77
+ num_workers=c.num_val_loader_workers
78
+ if is_val else c.num_loader_workers,
79
+ pin_memory=False)
80
+ return loader
81
+
82
+
83
+ def format_data(data):
84
+ # setup input data
85
+ text_input = data[0]
86
+ text_lengths = data[1]
87
+ speaker_names = data[2]
88
+ mel_input = data[4].permute(0, 2, 1) # B x D x T
89
+ mel_lengths = data[5]
90
+ item_idx = data[7]
91
+ attn_mask = data[9]
92
+ avg_text_length = torch.mean(text_lengths.float())
93
+ avg_spec_length = torch.mean(mel_lengths.float())
94
+
95
+ if c.use_speaker_embedding:
96
+ if c.use_external_speaker_embedding_file:
97
+ # return precomputed embedding vector
98
+ speaker_c = data[8]
99
+ else:
100
+ # return speaker_id to be used by an embedding layer
101
+ speaker_c = [
102
+ speaker_mapping[speaker_name] for speaker_name in speaker_names
103
+ ]
104
+ speaker_c = torch.LongTensor(speaker_c)
105
+ else:
106
+ speaker_c = None
107
+
108
+ # dispatch data to GPU
109
+ if use_cuda:
110
+ text_input = text_input.cuda(non_blocking=True)
111
+ text_lengths = text_lengths.cuda(non_blocking=True)
112
+ mel_input = mel_input.cuda(non_blocking=True)
113
+ mel_lengths = mel_lengths.cuda(non_blocking=True)
114
+ if speaker_c is not None:
115
+ speaker_c = speaker_c.cuda(non_blocking=True)
116
+ if attn_mask is not None:
117
+ attn_mask = attn_mask.cuda(non_blocking=True)
118
+ return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
119
+ avg_text_length, avg_spec_length, attn_mask, item_idx
120
+
121
+
122
+ def data_depended_init(data_loader, model, ap):
123
+ """Data depended initialization for activation normalization."""
124
+ if hasattr(model, 'module'):
125
+ for f in model.module.decoder.flows:
126
+ if getattr(f, "set_ddi", False):
127
+ f.set_ddi(True)
128
+ else:
129
+ for f in model.decoder.flows:
130
+ if getattr(f, "set_ddi", False):
131
+ f.set_ddi(True)
132
+
133
+ model.train()
134
+ print(" > Data depended initialization ... ")
135
+ num_iter = 0
136
+ with torch.no_grad():
137
+ for _, data in enumerate(data_loader):
138
+
139
+ # format data
140
+ text_input, text_lengths, mel_input, mel_lengths, speaker_embed,\
141
+ _, _, attn_mask, item_idx = format_data(data)
142
+
143
+ # forward pass model
144
+ _ = model.forward(
145
+ text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_embed)
146
+ if num_iter == c.data_dep_init_iter:
147
+ break
148
+ num_iter += 1
149
+
150
+ if hasattr(model, 'module'):
151
+ for f in model.module.decoder.flows:
152
+ if getattr(f, "set_ddi", False):
153
+ f.set_ddi(False)
154
+ else:
155
+ for f in model.decoder.flows:
156
+ if getattr(f, "set_ddi", False):
157
+ f.set_ddi(False)
158
+ return model
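# What the set_ddi toggling above enables, sketched as a generic
# activation-normalization layer: during the first no-grad forward passes the
# layer sets its scale and bias from batch statistics instead of fixed
# constants. This is an assumption about the general technique, not the exact
# Glow-TTS flow layer.

import torch
import torch.nn as nn

class ActNormSketch(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.logs = nn.Parameter(torch.zeros(1, channels, 1))
        self.bias = nn.Parameter(torch.zeros(1, channels, 1))
        self.ddi = False

    def set_ddi(self, ddi):
        self.ddi = ddi

    def forward(self, x):  # x: (B, C, T)
        if self.ddi:
            with torch.no_grad():
                mean = x.mean(dim=(0, 2), keepdim=True)
                std = x.std(dim=(0, 2), keepdim=True) + 1e-6
                self.bias.data = -mean / std      # shift so output mean ~ 0
                self.logs.data = -torch.log(std)  # scale so output std ~ 1
            self.ddi = False  # initialize from data only once
        return x * torch.exp(self.logs) + self.bias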
159
+
160
+
161
+ def train(data_loader, model, criterion, optimizer, scheduler,
162
+ ap, global_step, epoch):
163
+
164
+ model.train()
165
+ epoch_time = 0
166
+ keep_avg = KeepAverage()
167
+ if use_cuda:
168
+ batch_n_iter = int(
169
+ len(data_loader.dataset) / (c.batch_size * num_gpus))
170
+ else:
171
+ batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
172
+ end_time = time.time()
173
+ c_logger.print_train_start()
174
+ scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
175
+ for num_iter, data in enumerate(data_loader):
176
+ start_time = time.time()
177
+
178
+ # format data
179
+ text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
180
+ avg_text_length, avg_spec_length, attn_mask, item_idx = format_data(data)
181
+
182
+ loader_time = time.time() - end_time
183
+
184
+ global_step += 1
185
+ optimizer.zero_grad()
186
+
187
+ # forward pass model
188
+ with torch.cuda.amp.autocast(enabled=c.mixed_precision):
189
+ z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
190
+ text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
191
+
192
+ # compute loss
193
+ loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
194
+ o_dur_log, o_total_dur, text_lengths)
195
+
196
+ # backward pass with loss scaling
197
+ if c.mixed_precision:
198
+ scaler.scale(loss_dict['loss']).backward()
199
+ scaler.unscale_(optimizer)
200
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
201
+ c.grad_clip)
202
+ scaler.step(optimizer)
203
+ scaler.update()
204
+ else:
205
+ loss_dict['loss'].backward()
206
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
207
+ c.grad_clip)
208
+ optimizer.step()
209
+
210
+ # setup lr
211
+ if c.noam_schedule:
212
+ scheduler.step()
213
+
214
+ # current_lr
215
+ current_lr = optimizer.param_groups[0]['lr']
216
+
217
+ # compute alignment error (the lower, the better)
218
+ align_error = 1 - alignment_diagonal_score(alignments, binary=True)
219
+ loss_dict['align_error'] = align_error
220
+
221
+ step_time = time.time() - start_time
222
+ epoch_time += step_time
223
+
224
+ # aggregate losses from processes
225
+ if num_gpus > 1:
226
+ loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
227
+ loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
228
+ loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
229
+
230
+ # detach loss values
231
+ loss_dict_new = dict()
232
+ for key, value in loss_dict.items():
233
+ if isinstance(value, (int, float)):
234
+ loss_dict_new[key] = value
235
+ else:
236
+ loss_dict_new[key] = value.item()
237
+ loss_dict = loss_dict_new
238
+
239
+ # update avg stats
240
+ update_train_values = dict()
241
+ for key, value in loss_dict.items():
242
+ update_train_values['avg_' + key] = value
243
+ update_train_values['avg_loader_time'] = loader_time
244
+ update_train_values['avg_step_time'] = step_time
245
+ keep_avg.update_values(update_train_values)
246
+
247
+ # print training progress
248
+ if global_step % c.print_step == 0:
249
+ log_dict = {
250
+ "avg_spec_length": [avg_spec_length, 1], # value, precision
251
+ "avg_text_length": [avg_text_length, 1],
252
+ "step_time": [step_time, 4],
253
+ "loader_time": [loader_time, 2],
254
+ "current_lr": current_lr,
255
+ }
256
+ c_logger.print_train_step(batch_n_iter, num_iter, global_step,
257
+ log_dict, loss_dict, keep_avg.avg_values)
258
+
259
+ if args.rank == 0:
260
+ # Plot Training Iter Stats
261
+ # reduce TB load
262
+ if global_step % c.tb_plot_step == 0:
263
+ iter_stats = {
264
+ "lr": current_lr,
265
+ "grad_norm": grad_norm,
266
+ "step_time": step_time
267
+ }
268
+ iter_stats.update(loss_dict)
269
+ tb_logger.tb_train_iter_stats(global_step, iter_stats)
270
+
271
+ if global_step % c.save_step == 0:
272
+ if c.checkpoint:
273
+ # save model
274
+ save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
275
+ model_loss=loss_dict['loss'])
276
+
277
+ # wait all kernels to be completed
278
+ torch.cuda.synchronize()
279
+
280
+ # Diagnostic visualizations
281
+ # direct pass on model for spec predictions
282
+ target_speaker = None if speaker_c is None else speaker_c[:1]
283
+
284
+ if hasattr(model, 'module'):
285
+ spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
286
+ else:
287
+ spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
288
+
289
+ spec_pred = spec_pred.permute(0, 2, 1)
290
+ gt_spec = mel_input.permute(0, 2, 1)
291
+ const_spec = spec_pred[0].data.cpu().numpy()
292
+ gt_spec = gt_spec[0].data.cpu().numpy()
293
+ align_img = alignments[0].data.cpu().numpy()
294
+
295
+ figures = {
296
+ "prediction": plot_spectrogram(const_spec, ap),
297
+ "ground_truth": plot_spectrogram(gt_spec, ap),
298
+ "alignment": plot_alignment(align_img),
299
+ }
300
+
301
+ tb_logger.tb_train_figures(global_step, figures)
302
+
303
+ # Sample audio
304
+ train_audio = ap.inv_melspectrogram(const_spec.T)
305
+ tb_logger.tb_train_audios(global_step,
306
+ {'TrainAudio': train_audio},
307
+ c.audio["sample_rate"])
308
+ end_time = time.time()
309
+
310
+ # print epoch stats
311
+ c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
312
+
313
+ # Plot Epoch Stats
314
+ if args.rank == 0:
315
+ epoch_stats = {"epoch_time": epoch_time}
316
+ epoch_stats.update(keep_avg.avg_values)
317
+ tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
318
+ if c.tb_model_param_stats:
319
+ tb_logger.tb_model_weights(model, global_step)
320
+ return keep_avg.avg_values, global_step
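# The mixed-precision branch above follows the standard torch.cuda.amp recipe:
# forward under autocast, backward on the scaled loss, unscale before gradient
# clipping, then scaler.step / scaler.update. A condensed, self-contained
# sketch of one step (assuming a criterion that takes the model output
# directly):

import torch

def amp_train_step(model, batch, criterion, optimizer, scaler, grad_clip):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(batch))
    scaler.scale(loss).backward()      # backward on the scaled loss
    scaler.unscale_(optimizer)         # unscale so clipping sees true grads
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)             # skipped internally if grads overflowed
    scaler.update()
    return loss.item()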
321
+
322
+
323
+ @torch.no_grad()
324
+ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
325
+ model.eval()
326
+ epoch_time = 0
327
+ keep_avg = KeepAverage()
328
+ c_logger.print_eval_start()
329
+ if data_loader is not None:
330
+ for num_iter, data in enumerate(data_loader):
331
+ start_time = time.time()
332
+
333
+ # format data
334
+ text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
335
+ _, _, attn_mask, item_idx = format_data(data)
336
+
337
+ # forward pass model
338
+ z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
339
+ text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
340
+
341
+ # compute loss
342
+ loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
343
+ o_dur_log, o_total_dur, text_lengths)
344
+
345
+ # step time
346
+ step_time = time.time() - start_time
347
+ epoch_time += step_time
348
+
349
+ # compute alignment score
350
+ align_error = 1 - alignment_diagonal_score(alignments)
351
+ loss_dict['align_error'] = align_error
352
+
353
+ # aggregate losses from processes
354
+ if num_gpus > 1:
355
+ loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
356
+ loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
357
+ loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
358
+
359
+ # detach loss values
360
+ loss_dict_new = dict()
361
+ for key, value in loss_dict.items():
362
+ if isinstance(value, (int, float)):
363
+ loss_dict_new[key] = value
364
+ else:
365
+ loss_dict_new[key] = value.item()
366
+ loss_dict = loss_dict_new
367
+
368
+ # update avg stats
369
+ update_train_values = dict()
370
+ for key, value in loss_dict.items():
371
+ update_train_values['avg_' + key] = value
372
+ keep_avg.update_values(update_train_values)
373
+
374
+ if c.print_eval:
375
+ c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
376
+
377
+ if args.rank == 0:
378
+ # Diagnostic visualizations
379
+ # direct pass on model for spec predictions
380
+ target_speaker = None if speaker_c is None else speaker_c[:1]
381
+ if hasattr(model, 'module'):
382
+ spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
383
+ else:
384
+ spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
385
+ spec_pred = spec_pred.permute(0, 2, 1)
386
+ gt_spec = mel_input.permute(0, 2, 1)
387
+
388
+ const_spec = spec_pred[0].data.cpu().numpy()
389
+ gt_spec = gt_spec[0].data.cpu().numpy()
390
+ align_img = alignments[0].data.cpu().numpy()
391
+
392
+ eval_figures = {
393
+ "prediction": plot_spectrogram(const_spec, ap),
394
+ "ground_truth": plot_spectrogram(gt_spec, ap),
395
+ "alignment": plot_alignment(align_img)
396
+ }
397
+
398
+ # Sample audio
399
+ eval_audio = ap.inv_melspectrogram(const_spec.T)
400
+ tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
401
+ c.audio["sample_rate"])
402
+
403
+ # Plot Validation Stats
404
+ tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
405
+ tb_logger.tb_eval_figures(global_step, eval_figures)
406
+
407
+ if args.rank == 0 and epoch >= c.test_delay_epochs:
408
+ if c.test_sentences_file is None:
409
+ test_sentences = [
410
+ "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
411
+ "Be a voice, not an echo.",
412
+ "I'm sorry Dave. I'm afraid I can't do that.",
413
+ "This cake is great. It's so delicious and moist.",
414
+ "Prior to November 22, 1963."
415
+ ]
416
+ else:
417
+ with open(c.test_sentences_file, "r") as f:
418
+ test_sentences = [s.strip() for s in f.readlines()]
419
+
420
+ # test sentences
421
+ test_audios = {}
422
+ test_figures = {}
423
+ print(" | > Synthesizing test sentences")
424
+ if c.use_speaker_embedding:
425
+ if c.use_external_speaker_embedding_file:
426
+ speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding']
427
+ speaker_id = None
428
+ else:
429
+ speaker_id = 0
430
+ speaker_embedding = None
431
+ else:
432
+ speaker_id = None
433
+ speaker_embedding = None
434
+
435
+ style_wav = c.get("style_wav_for_test")
436
+ for idx, test_sentence in enumerate(test_sentences):
437
+ try:
438
+ wav, alignment, _, postnet_output, _, _ = synthesis(
439
+ model,
440
+ test_sentence,
441
+ c,
442
+ use_cuda,
443
+ ap,
444
+ speaker_id=speaker_id,
445
+ speaker_embedding=speaker_embedding,
446
+ style_wav=style_wav,
447
+ truncated=False,
448
+ enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
449
+ use_griffin_lim=True,
450
+ do_trim_silence=False)
451
+
452
+ file_path = os.path.join(AUDIO_PATH, str(global_step))
453
+ os.makedirs(file_path, exist_ok=True)
454
+ file_path = os.path.join(file_path,
455
+ "TestSentence_{}.wav".format(idx))
456
+ ap.save_wav(wav, file_path)
457
+ test_audios['{}-audio'.format(idx)] = wav
458
+ test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
459
+ postnet_output, ap)
460
+ test_figures['{}-alignment'.format(idx)] = plot_alignment(
461
+ alignment)
462
+ except: #pylint: disable=bare-except
463
+ print(" !! Error creating Test Sentence -", idx)
464
+ traceback.print_exc()
465
+ tb_logger.tb_test_audios(global_step, test_audios,
466
+ c.audio['sample_rate'])
467
+ tb_logger.tb_test_figures(global_step, test_figures)
468
+ return keep_avg.avg_values
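# `reduce_tensor` (imported from TTS.utils.distribute) averages a loss value
# across processes so multi-GPU logs agree. A plausible sketch of such a
# helper, assuming it wraps torch.distributed.all_reduce:

import torch.distributed as dist

def reduce_tensor_sketch(tensor, num_gpus):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # sum over all processes
    rt /= num_gpus                             # then average
    return rt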
469
+
470
+
471
+ # FIXME: move args definition/parsing inside of main?
472
+ def main(args): # pylint: disable=redefined-outer-name
473
+ # pylint: disable=global-variable-undefined
474
+ global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
475
+ # Audio processor
476
+ ap = AudioProcessor(**c.audio)
477
+ if 'characters' in c.keys():
478
+ symbols, phonemes = make_symbols(**c.characters)
479
+
480
+ # DISTRIBUTED
481
+ if num_gpus > 1:
482
+ init_distributed(args.rank, num_gpus, args.group_id,
483
+ c.distributed["backend"], c.distributed["url"])
484
+ num_chars = len(phonemes) if c.use_phonemes else len(symbols)
485
+
486
+ # load data instances
487
+ meta_data_train, meta_data_eval = load_meta_data(c.datasets)
488
+
489
+ # set the portion of the data used for training
490
+ if 'train_portion' in c.keys():
491
+ meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
492
+ if 'eval_portion' in c.keys():
493
+ meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
494
+
495
+ # parse speakers
496
+ num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
497
+
498
+ # setup model
499
+ model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
500
+ optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
501
+ criterion = GlowTTSLoss()
502
+
503
+ if args.restore_path:
504
+ checkpoint = torch.load(args.restore_path, map_location='cpu')
505
+ try:
506
+ # TODO: fix optimizer init, model.cuda() needs to be called before
507
+ # optimizer restore
508
+ optimizer.load_state_dict(checkpoint['optimizer'])
509
+ if c.reinit_layers:
510
+ raise RuntimeError
511
+ model.load_state_dict(checkpoint['model'])
512
+ except: #pylint: disable=bare-except
513
+ print(" > Partial model initialization.")
514
+ model_dict = model.state_dict()
515
+ model_dict = set_init_dict(model_dict, checkpoint['model'], c)
516
+ model.load_state_dict(model_dict)
517
+ del model_dict
518
+
519
+ for group in optimizer.param_groups:
520
+ group['initial_lr'] = c.lr
521
+ print(" > Model restored from step %d" % checkpoint['step'],
522
+ flush=True)
523
+ args.restore_step = checkpoint['step']
524
+ else:
525
+ args.restore_step = 0
526
+
527
+ if use_cuda:
528
+ model.cuda()
529
+ criterion.cuda()
530
+
531
+ # DISTRIBUTED
532
+ if num_gpus > 1:
533
+ model = DDP_th(model, device_ids=[args.rank])
534
+
535
+ if c.noam_schedule:
536
+ scheduler = NoamLR(optimizer,
537
+ warmup_steps=c.warmup_steps,
538
+ last_epoch=args.restore_step - 1)
539
+ else:
540
+ scheduler = None
541
+
542
+ num_params = count_parameters(model)
543
+ print("\n > Model has {} parameters".format(num_params), flush=True)
544
+
545
+ if 'best_loss' not in locals():
546
+ best_loss = float('inf')
547
+
548
+ # define dataloaders
549
+ train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
550
+ eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)
551
+
552
+ global_step = args.restore_step
553
+ model = data_depended_init(train_loader, model, ap)
554
+ for epoch in range(0, c.epochs):
555
+ c_logger.print_epoch_start(epoch, c.epochs)
556
+ train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
557
+ scheduler, ap, global_step,
558
+ epoch)
559
+ eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
560
+ c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
561
+ target_loss = train_avg_loss_dict['avg_loss']
562
+ if c.run_eval:
563
+ target_loss = eval_avg_loss_dict['avg_loss']
564
+ best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
565
+ OUT_PATH)
566
+
567
+
568
+ if __name__ == '__main__':
569
+ parser = argparse.ArgumentParser()
570
+ parser.add_argument(
571
+ '--continue_path',
572
+ type=str,
573
+ help='Output folder of a previous run to continue training from. If set, "config_path" is ignored.',
574
+ default='',
575
+ required='--config_path' not in sys.argv)
576
+ parser.add_argument(
577
+ '--restore_path',
578
+ type=str,
579
+ help='Model file to be restored. Use to finetune a model.',
580
+ default='')
581
+ parser.add_argument(
582
+ '--config_path',
583
+ type=str,
584
+ help='Path to config file for training.',
585
+ required='--continue_path' not in sys.argv
586
+ )
587
+ parser.add_argument('--debug',
588
+ type=bool,
589
+ default=False,
590
+ help='Do not verify commit integrity to run training.')
591
+
592
+ # DISTRIBUTED
593
+ parser.add_argument(
594
+ '--rank',
595
+ type=int,
596
+ default=0,
597
+ help='DISTRIBUTED: process rank for distributed training.')
598
+ parser.add_argument('--group_id',
599
+ type=str,
600
+ default="",
601
+ help='DISTRIBUTED: process group id.')
602
+ args = parser.parse_args()
603
+
604
+ if args.continue_path != '':
605
+ args.output_path = args.continue_path
606
+ args.config_path = os.path.join(args.continue_path, 'config.json')
607
+ list_of_files = glob.glob(args.continue_path + "/*.pth.tar")  # pick up all checkpoint files in the folder
608
+ latest_model_file = max(list_of_files, key=os.path.getctime)
609
+ args.restore_path = latest_model_file
610
+ print(f" > Training continues for {args.restore_path}")
611
+
612
+ # setup output paths and read configs
613
+ c = load_config(args.config_path)
614
+ # check_config(c)
615
+ check_config_tts(c)
616
+ _ = os.path.dirname(os.path.realpath(__file__))
617
+
618
+ if c.mixed_precision:
619
+ print(" > Mixed precision enabled.")
620
+
621
+ OUT_PATH = args.continue_path
622
+ if args.continue_path == '':
623
+ OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
624
+
625
+ AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
626
+
627
+ c_logger = ConsoleLogger()
628
+
629
+ if args.rank == 0:
630
+ os.makedirs(AUDIO_PATH, exist_ok=True)
631
+ new_fields = {}
632
+ if args.restore_path:
633
+ new_fields["restore_path"] = args.restore_path
634
+ new_fields["github_branch"] = get_git_branch()
635
+ copy_model_files(c, args.config_path,
636
+ OUT_PATH, new_fields)
637
+ os.chmod(AUDIO_PATH, 0o775)
638
+ os.chmod(OUT_PATH, 0o775)
639
+
640
+ LOG_DIR = OUT_PATH
641
+ tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
642
+
643
+ # write model desc to tensorboard
644
+ tb_logger.tb_add_text('model-description', c['run_description'], 0)
645
+
646
+ try:
647
+ main(args)
648
+ except KeyboardInterrupt:
649
+ remove_experiment_folder(OUT_PATH)
650
+ try:
651
+ sys.exit(0)
652
+ except SystemExit:
653
+ os._exit(0) # pylint: disable=protected-access
654
+ except Exception: # pylint: disable=broad-except
655
+ remove_experiment_folder(OUT_PATH)
656
+ traceback.print_exc()
657
+ sys.exit(1)
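Both this script and the SpeedySpeech trainer below can wrap the optimizer in NoamLR when noam_schedule is enabled: the learning rate warms up linearly for warmup_steps and then decays with the inverse square root of the step. A minimal sketch of the multiplier, assuming the standard Noam formulation:

def noam_lr_factor(step, warmup_steps):
    # linear warmup, then ~ step**-0.5 decay; the factor peaks at 1.0
    # exactly when step == warmup_steps
    step = max(step, 1)
    return warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)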
TTS/bin/train_speedy_speech.py ADDED
@@ -0,0 +1,618 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import argparse
5
+ import glob
6
+ import os
7
+ import sys
8
+ import time
9
+ import traceback
10
+ import numpy as np
11
+ from random import randrange
12
+
13
+ import torch
14
+ # DISTRIBUTED
15
+ from torch.nn.parallel import DistributedDataParallel as DDP_th
16
+ from torch.utils.data import DataLoader
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ from TTS.tts.datasets.preprocess import load_meta_data
19
+ from TTS.tts.datasets.TTSDataset import MyDataset
20
+ from TTS.tts.layers.losses import SpeedySpeechLoss
21
+ from TTS.tts.utils.generic_utils import check_config_tts, setup_model
22
+ from TTS.tts.utils.io import save_best_model, save_checkpoint
23
+ from TTS.tts.utils.measures import alignment_diagonal_score
24
+ from TTS.tts.utils.speakers import parse_speakers
25
+ from TTS.tts.utils.synthesis import synthesis
26
+ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
27
+ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
28
+ from TTS.utils.audio import AudioProcessor
29
+ from TTS.utils.console_logger import ConsoleLogger
30
+ from TTS.utils.distribute import init_distributed, reduce_tensor
31
+ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
32
+ create_experiment_folder, get_git_branch,
33
+ remove_experiment_folder, set_init_dict)
34
+ from TTS.utils.io import copy_model_files, load_config
35
+ from TTS.utils.radam import RAdam
36
+ from TTS.utils.tensorboard_logger import TensorboardLogger
37
+ from TTS.utils.training import NoamLR, setup_torch_training_env
38
+
39
+ use_cuda, num_gpus = setup_torch_training_env(True, False)
40
+
41
+
42
+ def setup_loader(ap, r, is_val=False, verbose=False):
43
+ if is_val and not c.run_eval:
44
+ loader = None
45
+ else:
46
+ dataset = MyDataset(
47
+ r,
48
+ c.text_cleaner,
49
+ compute_linear_spec=False,
50
+ meta_data=meta_data_eval if is_val else meta_data_train,
51
+ ap=ap,
52
+ tp=c.characters if 'characters' in c.keys() else None,
53
+ add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
54
+ batch_group_size=0 if is_val else c.batch_group_size *
55
+ c.batch_size,
56
+ min_seq_len=c.min_seq_len,
57
+ max_seq_len=c.max_seq_len,
58
+ phoneme_cache_path=c.phoneme_cache_path,
59
+ use_phonemes=c.use_phonemes,
60
+ phoneme_language=c.phoneme_language,
61
+ enable_eos_bos=c.enable_eos_bos_chars,
62
+ use_noise_augment=not is_val,
63
+ verbose=verbose,
64
+ speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
65
+
66
+ if c.use_phonemes and c.compute_input_seq_cache:
67
+ # precompute phonemes to have a better estimate of sequence lengths.
68
+ dataset.compute_input_seq(c.num_loader_workers)
69
+ dataset.sort_items()
70
+
71
+ sampler = DistributedSampler(dataset) if num_gpus > 1 else None
72
+ loader = DataLoader(
73
+ dataset,
74
+ batch_size=c.eval_batch_size if is_val else c.batch_size,
75
+ shuffle=False,
76
+ collate_fn=dataset.collate_fn,
77
+ drop_last=False,
78
+ sampler=sampler,
79
+ num_workers=c.num_val_loader_workers
80
+ if is_val else c.num_loader_workers,
81
+ pin_memory=False)
82
+ return loader
83
+
84
+
85
+ def format_data(data):
86
+ # setup input data
87
+ text_input = data[0]
88
+ text_lengths = data[1]
89
+ speaker_names = data[2]
90
+ mel_input = data[4].permute(0, 2, 1) # B x D x T
91
+ mel_lengths = data[5]
92
+ item_idx = data[7]
93
+ attn_mask = data[9]
94
+ avg_text_length = torch.mean(text_lengths.float())
95
+ avg_spec_length = torch.mean(mel_lengths.float())
96
+
97
+ if c.use_speaker_embedding:
98
+ if c.use_external_speaker_embedding_file:
99
+ # return precomputed embedding vector
100
+ speaker_c = data[8]
101
+ else:
102
+ # return speaker_id to be used by an embedding layer
103
+ speaker_c = [
104
+ speaker_mapping[speaker_name] for speaker_name in speaker_names
105
+ ]
106
+ speaker_c = torch.LongTensor(speaker_c)
107
+ else:
108
+ speaker_c = None
109
+ # compute durations from attention mask
110
+ durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
111
+ for idx, am in enumerate(attn_mask):
112
+ # compute raw durations
113
+ c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1]
114
+ # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
115
+ c_idxs, counts = torch.unique(c_idxs, return_counts=True)
116
+ dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
117
+ dur[c_idxs] = counts
118
+ # smooth the durations and set any 0 duration to 1
119
+ # by cutting frames off the largest duration indices.
120
+ extra_frames = dur.sum() - mel_lengths[idx]
121
+ largest_idxs = torch.argsort(-dur)[:extra_frames]
122
+ dur[largest_idxs] -= 1
123
+ assert dur.sum() == mel_lengths[idx], f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
124
+ durations[idx, :text_lengths[idx]] = dur
125
+ # dispatch data to GPU
126
+ if use_cuda:
127
+ text_input = text_input.cuda(non_blocking=True)
128
+ text_lengths = text_lengths.cuda(non_blocking=True)
129
+ mel_input = mel_input.cuda(non_blocking=True)
130
+ mel_lengths = mel_lengths.cuda(non_blocking=True)
131
+ if speaker_c is not None:
132
+ speaker_c = speaker_c.cuda(non_blocking=True)
133
+ attn_mask = attn_mask.cuda(non_blocking=True)
134
+ durations = durations.cuda(non_blocking=True)
135
+ return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
136
+ avg_text_length, avg_spec_length, attn_mask, durations, item_idx
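# The duration computation above, condensed for a single item (illustrative
# sketch only): pick the winning character for every spectrogram frame, count
# frames per character, give zero-duration characters one frame, then trim the
# longest characters so the total matches the spectrogram length.

import torch

def durations_from_attention(am, text_len, mel_len):
    # am: hard attention map of shape (1, max_text_len, max_mel_len);
    # text_len and mel_len are plain Python ints here
    c_idxs = am[:, :text_len, :mel_len].max(1)[1]       # char index per frame
    c_idxs, counts = torch.unique(c_idxs, return_counts=True)
    dur = torch.ones(text_len, dtype=counts.dtype)      # fill zeros with 1
    dur[c_idxs] = counts
    extra = int(dur.sum() - mel_len)                    # frames added above
    if extra > 0:
        dur[torch.argsort(-dur)[:extra]] -= 1           # trim longest chars
    assert dur.sum() == mel_len
    return dur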
137
+
138
+
139
+ def train(data_loader, model, criterion, optimizer, scheduler,
140
+ ap, global_step, epoch):
141
+
142
+ model.train()
143
+ epoch_time = 0
144
+ keep_avg = KeepAverage()
145
+ if use_cuda:
146
+ batch_n_iter = int(
147
+ len(data_loader.dataset) / (c.batch_size * num_gpus))
148
+ else:
149
+ batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
150
+ end_time = time.time()
151
+ c_logger.print_train_start()
152
+ scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
153
+ for num_iter, data in enumerate(data_loader):
154
+ start_time = time.time()
155
+
156
+ # format data
157
+ text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
158
+ avg_text_length, avg_spec_length, _, dur_target, _ = format_data(data)
159
+
160
+ loader_time = time.time() - end_time
161
+
162
+ global_step += 1
163
+ optimizer.zero_grad()
164
+
165
+ # forward pass model
166
+ with torch.cuda.amp.autocast(enabled=c.mixed_precision):
167
+ decoder_output, dur_output, alignments = model.forward(
168
+ text_input, text_lengths, mel_lengths, dur_target, g=speaker_c)
169
+
170
+ # compute loss
171
+ loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths)
172
+
173
+ # backward pass with loss scaling
174
+ if c.mixed_precision:
175
+ scaler.scale(loss_dict['loss']).backward()
176
+ scaler.unscale_(optimizer)
177
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
178
+ c.grad_clip)
179
+ scaler.step(optimizer)
180
+ scaler.update()
181
+ else:
182
+ loss_dict['loss'].backward()
183
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
184
+ c.grad_clip)
185
+ optimizer.step()
186
+
187
+ # setup lr
188
+ if c.noam_schedule:
189
+ scheduler.step()
190
+
191
+ # current_lr
192
+ current_lr = optimizer.param_groups[0]['lr']
193
+
194
+ # compute alignment error (the lower, the better)
195
+ align_error = 1 - alignment_diagonal_score(alignments, binary=True)
196
+ loss_dict['align_error'] = align_error
197
+
198
+ step_time = time.time() - start_time
199
+ epoch_time += step_time
200
+
201
+ # aggregate losses from processes
202
+ if num_gpus > 1:
203
+ loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus)
204
+ loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus)
205
+ loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
206
+ loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
207
+
208
+ # detach loss values
209
+ loss_dict_new = dict()
210
+ for key, value in loss_dict.items():
211
+ if isinstance(value, (int, float)):
212
+ loss_dict_new[key] = value
213
+ else:
214
+ loss_dict_new[key] = value.item()
215
+ loss_dict = loss_dict_new
216
+
217
+ # update avg stats
218
+ update_train_values = dict()
219
+ for key, value in loss_dict.items():
220
+ update_train_values['avg_' + key] = value
221
+ update_train_values['avg_loader_time'] = loader_time
222
+ update_train_values['avg_step_time'] = step_time
223
+ keep_avg.update_values(update_train_values)
224
+
225
+ # print training progress
226
+ if global_step % c.print_step == 0:
227
+ log_dict = {
228
+
229
+ "avg_spec_length": [avg_spec_length, 1], # value, precision
230
+ "avg_text_length": [avg_text_length, 1],
231
+ "step_time": [step_time, 4],
232
+ "loader_time": [loader_time, 2],
233
+ "current_lr": current_lr,
234
+ }
235
+ c_logger.print_train_step(batch_n_iter, num_iter, global_step,
236
+ log_dict, loss_dict, keep_avg.avg_values)
237
+
238
+ if args.rank == 0:
239
+ # Plot Training Iter Stats
240
+ # reduce TB load
241
+ if global_step % c.tb_plot_step == 0:
242
+ iter_stats = {
243
+ "lr": current_lr,
244
+ "grad_norm": grad_norm,
245
+ "step_time": step_time
246
+ }
247
+ iter_stats.update(loss_dict)
248
+ tb_logger.tb_train_iter_stats(global_step, iter_stats)
249
+
250
+ if global_step % c.save_step == 0:
251
+ if c.checkpoint:
252
+ # save model
253
+ save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
254
+ model_loss=loss_dict['loss'])
255
+
256
+ # wait all kernels to be completed
257
+ torch.cuda.synchronize()
258
+
259
+ # Diagnostic visualizations
260
+ idx = np.random.randint(mel_targets.shape[0])
261
+ pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
262
+ gt_spec = mel_targets[idx].data.cpu().numpy().T
263
+ align_img = alignments[idx].data.cpu()
264
+
265
+ figures = {
266
+ "prediction": plot_spectrogram(pred_spec, ap),
267
+ "ground_truth": plot_spectrogram(gt_spec, ap),
268
+ "alignment": plot_alignment(align_img),
269
+ }
270
+
271
+ tb_logger.tb_train_figures(global_step, figures)
272
+
273
+ # Sample audio
274
+ train_audio = ap.inv_melspectrogram(pred_spec.T)
275
+ tb_logger.tb_train_audios(global_step,
276
+ {'TrainAudio': train_audio},
277
+ c.audio["sample_rate"])
278
+ end_time = time.time()
279
+
280
+ # print epoch stats
281
+ c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
282
+
283
+ # Plot Epoch Stats
284
+ if args.rank == 0:
285
+ epoch_stats = {"epoch_time": epoch_time}
286
+ epoch_stats.update(keep_avg.avg_values)
287
+ tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
288
+ if c.tb_model_param_stats:
289
+ tb_logger.tb_model_weights(model, global_step)
290
+ return keep_avg.avg_values, global_step
291
+
292
+
293
+ @torch.no_grad()
294
+ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
295
+ model.eval()
296
+ epoch_time = 0
297
+ keep_avg = KeepAverage()
298
+ c_logger.print_eval_start()
299
+ if data_loader is not None:
300
+ for num_iter, data in enumerate(data_loader):
301
+ start_time = time.time()
302
+
303
+ # format data
304
+ text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
305
+ _, _, _, dur_target, _ = format_data(data)
306
+
307
+ # forward pass model
308
+ with torch.cuda.amp.autocast(enabled=c.mixed_precision):
309
+ decoder_output, dur_output, alignments = model.forward(
310
+ text_input, text_lengths, mel_lengths, dur_target, g=speaker_c)
311
+
312
+ # compute loss
313
+ loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths)
314
+
315
+ # step time
316
+ step_time = time.time() - start_time
317
+ epoch_time += step_time
318
+
319
+ # compute alignment score
320
+ align_error = 1 - alignment_diagonal_score(alignments, binary=True)
321
+ loss_dict['align_error'] = align_error
322
+
323
+ # aggregate losses from processes
324
+ if num_gpus > 1:
325
+ loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus)
326
+ loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus)
327
+ loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
328
+ loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
329
+
330
+ # detach loss values
331
+ loss_dict_new = dict()
332
+ for key, value in loss_dict.items():
333
+ if isinstance(value, (int, float)):
334
+ loss_dict_new[key] = value
335
+ else:
336
+ loss_dict_new[key] = value.item()
337
+ loss_dict = loss_dict_new
338
+
339
+ # update avg stats
340
+ update_train_values = dict()
341
+ for key, value in loss_dict.items():
342
+ update_train_values['avg_' + key] = value
343
+ keep_avg.update_values(update_train_values)
344
+
345
+ if c.print_eval:
346
+ c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
347
+
348
+ if args.rank == 0:
349
+ # Diagnostic visualizations
350
+ idx = np.random.randint(mel_targets.shape[0])
351
+ pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
352
+ gt_spec = mel_targets[idx].data.cpu().numpy().T
353
+ align_img = alignments[idx].data.cpu()
354
+
355
+ eval_figures = {
356
+ "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
357
+ "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
358
+ "alignment": plot_alignment(align_img, output_fig=False)
359
+ }
360
+
361
+ # Sample audio
362
+ eval_audio = ap.inv_melspectrogram(pred_spec.T)
363
+ tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
364
+ c.audio["sample_rate"])
365
+
366
+ # Plot Validation Stats
367
+ tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
368
+ tb_logger.tb_eval_figures(global_step, eval_figures)
369
+
370
+ if args.rank == 0 and epoch >= c.test_delay_epochs:
371
+ if c.test_sentences_file is None:
372
+ test_sentences = [
373
+ "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
374
+ "Be a voice, not an echo.",
375
+ "I'm sorry Dave. I'm afraid I can't do that.",
376
+ "This cake is great. It's so delicious and moist.",
377
+ "Prior to November 22, 1963."
378
+ ]
379
+ else:
380
+ with open(c.test_sentences_file, "r") as f:
381
+ test_sentences = [s.strip() for s in f.readlines()]
382
+
383
+ # test sentences
384
+ test_audios = {}
385
+ test_figures = {}
386
+ print(" | > Synthesizing test sentences")
387
+ if c.use_speaker_embedding:
388
+ if c.use_external_speaker_embedding_file:
389
+ speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding']
390
+ speaker_id = None
391
+ else:
392
+ speaker_id = 0
393
+ speaker_embedding = None
394
+ else:
395
+ speaker_id = None
396
+ speaker_embedding = None
397
+
398
+ style_wav = c.get("style_wav_for_test")
399
+ for idx, test_sentence in enumerate(test_sentences):
400
+ try:
401
+ wav, alignment, _, postnet_output, _, _ = synthesis(
402
+ model,
403
+ test_sentence,
404
+ c,
405
+ use_cuda,
406
+ ap,
407
+ speaker_id=speaker_id,
408
+ speaker_embedding=speaker_embedding,
409
+ style_wav=style_wav,
410
+ truncated=False,
411
+ enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
412
+ use_griffin_lim=True,
413
+ do_trim_silence=False)
414
+
415
+ file_path = os.path.join(AUDIO_PATH, str(global_step))
416
+ os.makedirs(file_path, exist_ok=True)
417
+ file_path = os.path.join(file_path,
418
+ "TestSentence_{}.wav".format(idx))
419
+ ap.save_wav(wav, file_path)
420
+ test_audios['{}-audio'.format(idx)] = wav
421
+ test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
422
+ postnet_output, ap)
423
+ test_figures['{}-alignment'.format(idx)] = plot_alignment(
424
+ alignment)
425
+ except: #pylint: disable=bare-except
426
+ print(" !! Error creating Test Sentence -", idx)
427
+ traceback.print_exc()
428
+ tb_logger.tb_test_audios(global_step, test_audios,
429
+ c.audio['sample_rate'])
430
+ tb_logger.tb_test_figures(global_step, test_figures)
431
+ return keep_avg.avg_values
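# Note that the criterion above compares the duration predictor against
# torch.log(1 + dur_target) rather than raw frame counts, which compresses the
# range of long durations. Inverting a predicted log-duration back to frames
# (a sketch; the actual inference-side decoding lives in the model code):

import torch

dur_target = torch.tensor([1., 4., 12.])   # frames per character
log_dur = torch.log(1 + dur_target)        # the training target
dur_back = torch.round(torch.exp(log_dur) - 1)
assert torch.equal(dur_back, dur_target)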
432
+
433
+
434
+ # FIXME: move args definition/parsing inside of main?
435
+ def main(args): # pylint: disable=redefined-outer-name
436
+ # pylint: disable=global-variable-undefined
437
+ global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
438
+ # Audio processor
439
+ ap = AudioProcessor(**c.audio)
440
+ if 'characters' in c.keys():
441
+ symbols, phonemes = make_symbols(**c.characters)
442
+
443
+ # DISTRIBUTED
444
+ if num_gpus > 1:
445
+ init_distributed(args.rank, num_gpus, args.group_id,
446
+ c.distributed["backend"], c.distributed["url"])
447
+ num_chars = len(phonemes) if c.use_phonemes else len(symbols)
448
+
449
+ # load data instances
450
+ meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)
451
+
452
+ # set the portion of the data used for training if set in config.json
453
+ if 'train_portion' in c.keys():
454
+ meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
455
+ if 'eval_portion' in c.keys():
456
+ meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
457
+
458
+ # parse speakers
459
+ num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
460
+
461
+ # setup model
462
+ model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
463
+ optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
464
+ criterion = SpeedySpeechLoss(c)
465
+
466
+ if args.restore_path:
467
+ checkpoint = torch.load(args.restore_path, map_location='cpu')
468
+ try:
469
+ # TODO: fix optimizer init, model.cuda() needs to be called before
470
+ # optimizer restore
471
+ optimizer.load_state_dict(checkpoint['optimizer'])
472
+ if c.reinit_layers:
473
+ raise RuntimeError
474
+ model.load_state_dict(checkpoint['model'])
475
+ except: #pylint: disable=bare-except
476
+ print(" > Partial model initialization.")
477
+ model_dict = model.state_dict()
478
+ model_dict = set_init_dict(model_dict, checkpoint['model'], c)
479
+ model.load_state_dict(model_dict)
480
+ del model_dict
481
+
482
+ for group in optimizer.param_groups:
483
+ group['initial_lr'] = c.lr
484
+ print(" > Model restored from step %d" % checkpoint['step'],
485
+ flush=True)
486
+ args.restore_step = checkpoint['step']
487
+ else:
488
+ args.restore_step = 0
489
+
490
+ if use_cuda:
491
+ model.cuda()
492
+ criterion.cuda()
493
+
494
+ # DISTRIBUTED
495
+ if num_gpus > 1:
496
+ model = DDP_th(model, device_ids=[args.rank])
497
+
498
+ if c.noam_schedule:
499
+ scheduler = NoamLR(optimizer,
500
+ warmup_steps=c.warmup_steps,
501
+ last_epoch=args.restore_step - 1)
502
+ else:
503
+ scheduler = None
504
+
505
+ num_params = count_parameters(model)
506
+ print("\n > Model has {} parameters".format(num_params), flush=True)
507
+
508
+ if 'best_loss' not in locals():
509
+ best_loss = float('inf')
510
+
511
+ # define dataloaders
512
+ train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
513
+ eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)
514
+
515
+ global_step = args.restore_step
516
+ for epoch in range(0, c.epochs):
517
+ c_logger.print_epoch_start(epoch, c.epochs)
518
+ train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
519
+ scheduler, ap, global_step,
520
+ epoch)
521
+ eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
522
+ c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
523
+ target_loss = train_avg_loss_dict['avg_loss']
524
+ if c.run_eval:
525
+ target_loss = eval_avg_loss_dict['avg_loss']
526
+ best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
527
+ OUT_PATH)
528
+
529
+
530
+ if __name__ == '__main__':
531
+ parser = argparse.ArgumentParser()
532
+ parser.add_argument(
533
+ '--continue_path',
534
+ type=str,
535
+ help='Training output folder to continue a previous training run. If set, "config_path" is ignored.',
536
+ default='',
537
+ required='--config_path' not in sys.argv)
538
+ parser.add_argument(
539
+ '--restore_path',
540
+ type=str,
541
+ help='Model file to be restored. Use to finetune a model.',
542
+ default='')
543
+ parser.add_argument(
544
+ '--config_path',
545
+ type=str,
546
+ help='Path to config file for training.',
547
+ required='--continue_path' not in sys.argv
548
+ )
549
+ parser.add_argument('--debug',
550
+ type=bool,
551
+ default=False,
552
+ help='Run in debug mode; skip the commit integrity check before training.')
553
+
554
+ # DISTRIBUTED
555
+ parser.add_argument(
556
+ '--rank',
557
+ type=int,
558
+ default=0,
559
+ help='DISTRIBUTED: process rank for distributed training.')
560
+ parser.add_argument('--group_id',
561
+ type=str,
562
+ default="",
563
+ help='DISTRIBUTED: process group id.')
564
+ args = parser.parse_args()
565
+
566
+ if args.continue_path != '':
567
+ args.output_path = args.continue_path
568
+ args.config_path = os.path.join(args.continue_path, 'config.json')
569
+ list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # pick up all checkpoint files in the output folder
570
+ latest_model_file = max(list_of_files, key=os.path.getctime)
571
+ args.restore_path = latest_model_file
572
+ print(f" > Training continues for {args.restore_path}")
573
+
574
+ # setup output paths and read configs
575
+ c = load_config(args.config_path)
576
+ # check_config(c)
577
+ check_config_tts(c)
578
+ _ = os.path.dirname(os.path.realpath(__file__))
579
+
580
+ if c.mixed_precision:
581
+ print(" > Mixed precision enabled.")
582
+
583
+ OUT_PATH = args.continue_path
584
+ if args.continue_path == '':
585
+ OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
586
+
587
+ AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
588
+
589
+ c_logger = ConsoleLogger()
590
+
591
+ if args.rank == 0:
592
+ os.makedirs(AUDIO_PATH, exist_ok=True)
593
+ new_fields = {}
594
+ if args.restore_path:
595
+ new_fields["restore_path"] = args.restore_path
596
+ new_fields["github_branch"] = get_git_branch()
597
+ copy_model_files(c, args.config_path, OUT_PATH, new_fields)
598
+ os.chmod(AUDIO_PATH, 0o775)
599
+ os.chmod(OUT_PATH, 0o775)
600
+
601
+ LOG_DIR = OUT_PATH
602
+ tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
603
+
604
+ # write model desc to tensorboard
605
+ tb_logger.tb_add_text('model-description', c['run_description'], 0)
606
+
607
+ try:
608
+ main(args)
609
+ except KeyboardInterrupt:
610
+ remove_experiment_folder(OUT_PATH)
611
+ try:
612
+ sys.exit(0)
613
+ except SystemExit:
614
+ os._exit(0) # pylint: disable=protected-access
615
+ except Exception: # pylint: disable=broad-except
616
+ remove_experiment_folder(OUT_PATH)
617
+ traceback.print_exc()
618
+ sys.exit(1)
TTS/bin/train_tacotron.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import argparse
5
+ import glob
6
+ import os
7
+ import sys
8
+ import time
9
+ import traceback
10
+ from random import randrange
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch.utils.data import DataLoader
15
+ from TTS.tts.datasets.preprocess import load_meta_data
16
+ from TTS.tts.datasets.TTSDataset import MyDataset
17
+ from TTS.tts.layers.losses import TacotronLoss
18
+ from TTS.tts.utils.generic_utils import check_config_tts, setup_model
19
+ from TTS.tts.utils.io import save_best_model, save_checkpoint
20
+ from TTS.tts.utils.measures import alignment_diagonal_score
21
+ from TTS.tts.utils.speakers import parse_speakers
22
+ from TTS.tts.utils.synthesis import synthesis
23
+ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
24
+ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
25
+ from TTS.utils.audio import AudioProcessor
26
+ from TTS.utils.console_logger import ConsoleLogger
27
+ from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
28
+ init_distributed, reduce_tensor)
29
+ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
30
+ create_experiment_folder, get_git_branch,
31
+ remove_experiment_folder, set_init_dict)
32
+ from TTS.utils.io import copy_model_files, load_config
33
+ from TTS.utils.radam import RAdam
34
+ from TTS.utils.tensorboard_logger import TensorboardLogger
35
+ from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
36
+ gradual_training_scheduler, set_weight_decay,
37
+ setup_torch_training_env)
38
+
39
+ use_cuda, num_gpus = setup_torch_training_env(True, False)
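+ # flags are presumably (cudnn_enable, cudnn_benchmark); benchmarking is left off here since Tacotron batches vary in sequence length.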
40
+
41
+
42
+ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
43
+ if is_val and not c.run_eval:
44
+ loader = None
45
+ else:
46
+ if dataset is None:
47
+ dataset = MyDataset(
48
+ r,
49
+ c.text_cleaner,
50
+ compute_linear_spec=c.model.lower() == 'tacotron',
51
+ meta_data=meta_data_eval if is_val else meta_data_train,
52
+ ap=ap,
53
+ tp=c.characters if 'characters' in c.keys() else None,
54
+ add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
55
+ batch_group_size=0 if is_val else c.batch_group_size *
56
+ c.batch_size,
57
+ min_seq_len=c.min_seq_len,
58
+ max_seq_len=c.max_seq_len,
59
+ phoneme_cache_path=c.phoneme_cache_path,
60
+ use_phonemes=c.use_phonemes,
61
+ phoneme_language=c.phoneme_language,
62
+ enable_eos_bos=c.enable_eos_bos_chars,
63
+ verbose=verbose,
64
+ speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
65
+
66
+ if c.use_phonemes and c.compute_input_seq_cache:
67
+ # precompute phonemes to have a better estimate of sequence lengths.
68
+ dataset.compute_input_seq(c.num_loader_workers)
69
+ dataset.sort_items()
70
+
71
+ sampler = DistributedSampler(dataset) if num_gpus > 1 else None
72
+ loader = DataLoader(
73
+ dataset,
74
+ batch_size=c.eval_batch_size if is_val else c.batch_size,
75
+ shuffle=False,
76
+ collate_fn=dataset.collate_fn,
77
+ drop_last=False,
78
+ sampler=sampler,
79
+ num_workers=c.num_val_loader_workers
80
+ if is_val else c.num_loader_workers,
81
+ pin_memory=False)
82
+ return loader
83
+
84
+ def format_data(data):
85
+ # setup input data
86
+ text_input = data[0]
87
+ text_lengths = data[1]
88
+ speaker_names = data[2]
89
+ linear_input = data[3] if c.model in ["Tacotron"] else None
90
+ mel_input = data[4]
91
+ mel_lengths = data[5]
92
+ stop_targets = data[6]
93
+ max_text_length = torch.max(text_lengths.float())
94
+ max_spec_length = torch.max(mel_lengths.float())
95
+
96
+ if c.use_speaker_embedding:
97
+ if c.use_external_speaker_embedding_file:
98
+ speaker_embeddings = data[8]
99
+ speaker_ids = None
100
+ else:
101
+ speaker_ids = [
102
+ speaker_mapping[speaker_name] for speaker_name in speaker_names
103
+ ]
104
+ speaker_ids = torch.LongTensor(speaker_ids)
105
+ speaker_embeddings = None
106
+ else:
107
+ speaker_embeddings = None
108
+ speaker_ids = None
109
+
110
+
111
+ # set stop targets view, we predict a single stop token per iteration.
112
+ stop_targets = stop_targets.view(text_input.shape[0],
113
+ stop_targets.size(1) // c.r, -1)
114
+ stop_targets = (stop_targets.sum(2) >
115
+ 0.0).unsqueeze(2).float().squeeze(2)
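+ # i.e. group stop frames into chunks of r and mark a chunk as "stop" if any of its frames is, giving one stop target per decoder step.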
116
+
117
+ # dispatch data to GPU
118
+ if use_cuda:
119
+ text_input = text_input.cuda(non_blocking=True)
120
+ text_lengths = text_lengths.cuda(non_blocking=True)
121
+ mel_input = mel_input.cuda(non_blocking=True)
122
+ mel_lengths = mel_lengths.cuda(non_blocking=True)
123
+ linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron"] else None
124
+ stop_targets = stop_targets.cuda(non_blocking=True)
125
+ if speaker_ids is not None:
126
+ speaker_ids = speaker_ids.cuda(non_blocking=True)
127
+ if speaker_embeddings is not None:
128
+ speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
129
+
130
+ return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length
131
+
132
+
133
+ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
134
+ ap, global_step, epoch, scaler, scaler_st):
135
+ model.train()
136
+ epoch_time = 0
137
+ keep_avg = KeepAverage()
138
+ if use_cuda:
139
+ batch_n_iter = int(
140
+ len(data_loader.dataset) / (c.batch_size * num_gpus))
141
+ else:
142
+ batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
143
+ end_time = time.time()
144
+ c_logger.print_train_start()
145
+ for num_iter, data in enumerate(data_loader):
146
+ start_time = time.time()
147
+
148
+ # format data
149
+ text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data)
150
+ loader_time = time.time() - end_time
151
+
152
+ global_step += 1
153
+
154
+ # setup lr
155
+ if c.noam_schedule:
156
+ scheduler.step()
157
+
158
+ optimizer.zero_grad()
159
+ if optimizer_st:
160
+ optimizer_st.zero_grad()
161
+
162
+ with torch.cuda.amp.autocast(enabled=c.mixed_precision):
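+ # under autocast the forward pass runs in fp16 when c.mixed_precision is on; the GradScaler below takes care of loss scaling.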
163
+ # forward pass model
164
+ if c.bidirectional_decoder or c.double_decoder_consistency:
165
+ decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
166
+ text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
167
+ else:
168
+ decoder_output, postnet_output, alignments, stop_tokens = model(
169
+ text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
170
+ decoder_backward_output = None
171
+ alignments_backward = None
172
+
173
+ # set the alignment lengths wrt reduction factor for guided attention
174
+ if mel_lengths.max() % model.decoder.r != 0:
175
+ alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
176
+ else:
177
+ alignment_lengths = mel_lengths // model.decoder.r
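+ # i.e. pad every mel length by the amount needed to make the longest one a multiple of r, then convert to decoder steps.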
178
+
179
+ # compute loss
180
+ loss_dict = criterion(postnet_output, decoder_output, mel_input,
181
+ linear_input, stop_tokens, stop_targets,
182
+ mel_lengths, decoder_backward_output,
183
+ alignments, alignment_lengths, alignments_backward,
184
+ text_lengths)
185
+
186
+ # check nan loss
187
+ if torch.isnan(loss_dict['loss']).any():
188
+ raise RuntimeError(f'Detected NaN loss at step {global_step}.')
189
+
190
+ # optimizer step
191
+ if c.mixed_precision:
192
+ # model optimizer step in mixed precision mode
193
+ scaler.scale(loss_dict['loss']).backward()
194
+ scaler.unscale_(optimizer)
195
+ optimizer, current_lr = adam_weight_decay(optimizer)
196
+ grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
197
+ scaler.step(optimizer)
198
+ scaler.update()
199
+
200
+ # stopnet optimizer step
201
+ if c.separate_stopnet:
202
+ scaler_st.scale(loss_dict['stopnet_loss']).backward()
203
+ scaler_st.unscale_(optimizer_st)
204
+ optimizer_st, _ = adam_weight_decay(optimizer_st)
205
+ grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
206
+ scaler_st.step(optimizer_st)
207
+ scaler_st.update()
208
+ else:
209
+ grad_norm_st = 0
210
+ else:
211
+ # main model optimizer step
212
+ loss_dict['loss'].backward()
213
+ optimizer, current_lr = adam_weight_decay(optimizer)
214
+ grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
215
+ optimizer.step()
216
+
217
+ # stopnet optimizer step
218
+ if c.separate_stopnet:
219
+ loss_dict['stopnet_loss'].backward()
220
+ optimizer_st, _ = adam_weight_decay(optimizer_st)
221
+ grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
222
+ optimizer_st.step()
223
+ else:
224
+ grad_norm_st = 0
225
+
226
+ # compute alignment error (the lower the better)
227
+ align_error = 1 - alignment_diagonal_score(alignments)
228
+ loss_dict['align_error'] = align_error
229
+
230
+ step_time = time.time() - start_time
231
+ epoch_time += step_time
232
+
233
+ # aggregate losses from processes
234
+ if num_gpus > 1:
235
+ loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
236
+ loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
237
+ loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
238
+ loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) if c.stopnet else loss_dict['stopnet_loss']
239
+
240
+ # detach loss values
241
+ loss_dict_new = dict()
242
+ for key, value in loss_dict.items():
243
+ if isinstance(value, (int, float)):
244
+ loss_dict_new[key] = value
245
+ else:
246
+ loss_dict_new[key] = value.item()
247
+ loss_dict = loss_dict_new
248
+
249
+ # update avg stats
250
+ update_train_values = dict()
251
+ for key, value in loss_dict.items():
252
+ update_train_values['avg_' + key] = value
253
+ update_train_values['avg_loader_time'] = loader_time
254
+ update_train_values['avg_step_time'] = step_time
255
+ keep_avg.update_values(update_train_values)
256
+
257
+ # print training progress
258
+ if global_step % c.print_step == 0:
259
+ log_dict = {
260
+ "max_spec_length": [max_spec_length, 1], # value, precision
261
+ "max_text_length": [max_text_length, 1],
262
+ "step_time": [step_time, 4],
263
+ "loader_time": [loader_time, 2],
264
+ "current_lr": current_lr,
265
+ }
266
+ c_logger.print_train_step(batch_n_iter, num_iter, global_step,
267
+ log_dict, loss_dict, keep_avg.avg_values)
268
+
269
+ if args.rank == 0:
270
+ # Plot Training Iter Stats
271
+ # reduce TB load
272
+ if global_step % c.tb_plot_step == 0:
273
+ iter_stats = {
274
+ "lr": current_lr,
275
+ "grad_norm": grad_norm,
276
+ "grad_norm_st": grad_norm_st,
277
+ "step_time": step_time
278
+ }
279
+ iter_stats.update(loss_dict)
280
+ tb_logger.tb_train_iter_stats(global_step, iter_stats)
281
+
282
+ if global_step % c.save_step == 0:
283
+ if c.checkpoint:
284
+ # save model
285
+ save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
286
+ optimizer_st=optimizer_st,
287
+ model_loss=loss_dict['postnet_loss'],
288
+ scaler=scaler.state_dict() if c.mixed_precision else None)
289
+
290
+ # Diagnostic visualizations
291
+ const_spec = postnet_output[0].data.cpu().numpy()
292
+ gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
293
+ "Tacotron", "TacotronGST"
294
+ ] else mel_input[0].data.cpu().numpy()
295
+ align_img = alignments[0].data.cpu().numpy()
296
+
297
+ figures = {
298
+ "prediction": plot_spectrogram(const_spec, ap, output_fig=False),
299
+ "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
300
+ "alignment": plot_alignment(align_img, output_fig=False),
301
+ }
302
+
303
+ if c.bidirectional_decoder or c.double_decoder_consistency:
304
+ figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
305
+
306
+ tb_logger.tb_train_figures(global_step, figures)
307
+
308
+ # Sample audio
309
+ if c.model in ["Tacotron", "TacotronGST"]:
310
+ train_audio = ap.inv_spectrogram(const_spec.T)
311
+ else:
312
+ train_audio = ap.inv_melspectrogram(const_spec.T)
313
+ tb_logger.tb_train_audios(global_step,
314
+ {'TrainAudio': train_audio},
315
+ c.audio["sample_rate"])
316
+ end_time = time.time()
317
+
318
+ # print epoch stats
319
+ c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
320
+
321
+ # Plot Epoch Stats
322
+ if args.rank == 0:
323
+ epoch_stats = {"epoch_time": epoch_time}
324
+ epoch_stats.update(keep_avg.avg_values)
325
+ tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
326
+ if c.tb_model_param_stats:
327
+ tb_logger.tb_model_weights(model, global_step)
328
+ return keep_avg.avg_values, global_step
329
+
330
+
331
+ @torch.no_grad()
332
+ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
333
+ model.eval()
334
+ epoch_time = 0
335
+ keep_avg = KeepAverage()
336
+ c_logger.print_eval_start()
337
+ if data_loader is not None:
338
+ for num_iter, data in enumerate(data_loader):
339
+ start_time = time.time()
340
+
341
+ # format data
342
+ text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data)
343
+ assert mel_input.shape[1] % model.decoder.r == 0
344
+
345
+ # forward pass model
346
+ if c.bidirectional_decoder or c.double_decoder_consistency:
347
+ decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
348
+ text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
349
+ else:
350
+ decoder_output, postnet_output, alignments, stop_tokens = model(
351
+ text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
352
+ decoder_backward_output = None
353
+ alignments_backward = None
354
+
355
+ # set the alignment lengths wrt reduction factor for guided attention
356
+ if mel_lengths.max() % model.decoder.r != 0:
357
+ alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
358
+ else:
359
+ alignment_lengths = mel_lengths // model.decoder.r
360
+
361
+ # compute loss
362
+ loss_dict = criterion(postnet_output, decoder_output, mel_input,
363
+ linear_input, stop_tokens, stop_targets,
364
+ mel_lengths, decoder_backward_output,
365
+ alignments, alignment_lengths, alignments_backward,
366
+ text_lengths)
367
+
368
+ # step time
369
+ step_time = time.time() - start_time
370
+ epoch_time += step_time
371
+
372
+ # compute alignment score
373
+ align_error = 1 - alignment_diagonal_score(alignments)
374
+ loss_dict['align_error'] = align_error
375
+
376
+ # aggregate losses from processes
377
+ if num_gpus > 1:
378
+ loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
379
+ loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
380
+ if c.stopnet:
381
+ loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus)
382
+
383
+ # detach loss values
384
+ loss_dict_new = dict()
385
+ for key, value in loss_dict.items():
386
+ if isinstance(value, (int, float)):
387
+ loss_dict_new[key] = value
388
+ else:
389
+ loss_dict_new[key] = value.item()
390
+ loss_dict = loss_dict_new
391
+
392
+ # update avg stats
393
+ update_train_values = dict()
394
+ for key, value in loss_dict.items():
395
+ update_train_values['avg_' + key] = value
396
+ keep_avg.update_values(update_train_values)
397
+
398
+ if c.print_eval:
399
+ c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
400
+
401
+ if args.rank == 0:
402
+ # Diagnostic visualizations
403
+ idx = np.random.randint(mel_input.shape[0])
404
+ const_spec = postnet_output[idx].data.cpu().numpy()
405
+ gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
406
+ "Tacotron", "TacotronGST"
407
+ ] else mel_input[idx].data.cpu().numpy()
408
+ align_img = alignments[idx].data.cpu().numpy()
409
+
410
+ eval_figures = {
411
+ "prediction": plot_spectrogram(const_spec, ap, output_fig=False),
412
+ "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
413
+ "alignment": plot_alignment(align_img, output_fig=False)
414
+ }
415
+
416
+ # Sample audio
417
+ if c.model in ["Tacotron", "TacotronGST"]:
418
+ eval_audio = ap.inv_spectrogram(const_spec.T)
419
+ else:
420
+ eval_audio = ap.inv_melspectrogram(const_spec.T)
421
+ tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
422
+ c.audio["sample_rate"])
423
+
424
+ # Plot Validation Stats
425
+
426
+ if c.bidirectional_decoder or c.double_decoder_consistency:
427
+ align_b_img = alignments_backward[idx].data.cpu().numpy()
428
+ eval_figures['alignment2'] = plot_alignment(align_b_img, output_fig=False)
429
+ tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
430
+ tb_logger.tb_eval_figures(global_step, eval_figures)
431
+
432
+ if args.rank == 0 and epoch > c.test_delay_epochs:
433
+ if c.test_sentences_file is None:
434
+ test_sentences = [
435
+ "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
436
+ "Be a voice, not an echo.",
437
+ "I'm sorry Dave. I'm afraid I can't do that.",
438
+ "This cake is great. It's so delicious and moist.",
439
+ "Prior to November 22, 1963."
440
+ ]
441
+ else:
442
+ with open(c.test_sentences_file, "r") as f:
443
+ test_sentences = [s.strip() for s in f.readlines()]
444
+
445
+ # test sentences
446
+ test_audios = {}
447
+ test_figures = {}
448
+ print(" | > Synthesizing test sentences")
449
+ speaker_id = 0 if c.use_speaker_embedding else None
450
+ speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping))]]['embedding'] if c.use_external_speaker_embedding_file and c.use_speaker_embedding else None
451
+ style_wav = c.get("gst_style_input")
452
+ if style_wav is None and c.use_gst:
453
+ # initialize GST with a zero dict.
454
+ style_wav = {}
455
+ print("WARNING: No GST style wav provided, so a zero style dict is used instead!")
456
+ for i in range(c.gst['gst_style_tokens']):
457
+ style_wav[str(i)] = 0
459
+ for idx, test_sentence in enumerate(test_sentences):
460
+ try:
461
+ wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
462
+ model,
463
+ test_sentence,
464
+ c,
465
+ use_cuda,
466
+ ap,
467
+ speaker_id=speaker_id,
468
+ speaker_embedding=speaker_embedding,
469
+ style_wav=style_wav,
470
+ truncated=False,
471
+ enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
472
+ use_griffin_lim=True,
473
+ do_trim_silence=False)
474
+
475
+ file_path = os.path.join(AUDIO_PATH, str(global_step))
476
+ os.makedirs(file_path, exist_ok=True)
477
+ file_path = os.path.join(file_path,
478
+ "TestSentence_{}.wav".format(idx))
479
+ ap.save_wav(wav, file_path)
480
+ test_audios['{}-audio'.format(idx)] = wav
481
+ test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
482
+ postnet_output, ap, output_fig=False)
483
+ test_figures['{}-alignment'.format(idx)] = plot_alignment(
484
+ alignment, output_fig=False)
485
+ except: #pylint: disable=bare-except
486
+ print(" !! Error creating Test Sentence -", idx)
487
+ traceback.print_exc()
488
+ tb_logger.tb_test_audios(global_step, test_audios,
489
+ c.audio['sample_rate'])
490
+ tb_logger.tb_test_figures(global_step, test_figures)
491
+ return keep_avg.avg_values
492
+
493
+
494
+ # FIXME: move args definition/parsing inside of main?
495
+ def main(args): # pylint: disable=redefined-outer-name
496
+ # pylint: disable=global-variable-undefined
497
+ global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
498
+ # Audio processor
499
+ ap = AudioProcessor(**c.audio)
500
+ if 'characters' in c.keys():
501
+ symbols, phonemes = make_symbols(**c.characters)
502
+
503
+ # DISTRIBUTED
504
+ if num_gpus > 1:
505
+ init_distributed(args.rank, num_gpus, args.group_id,
506
+ c.distributed["backend"], c.distributed["url"])
507
+ num_chars = len(phonemes) if c.use_phonemes else len(symbols)
508
+
509
+ # load data instances
510
+ meta_data_train, meta_data_eval = load_meta_data(c.datasets)
511
+
512
+ # set the portion of the data used for training
513
+ if 'train_portion' in c.keys():
514
+ meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
515
+ if 'eval_portion' in c.keys():
516
+ meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
517
+
518
+ # parse speakers
519
+ num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
520
+
521
+ model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim)
522
+
523
+ # scalers for mixed precision training
524
+ scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
525
+ scaler_st = torch.cuda.amp.GradScaler() if c.mixed_precision and c.separate_stopnet else None
526
+
527
+ params = set_weight_decay(model, c.wd)
528
+ optimizer = RAdam(params, lr=c.lr, weight_decay=0)
529
+ if c.stopnet and c.separate_stopnet:
530
+ optimizer_st = RAdam(model.decoder.stopnet.parameters(),
531
+ lr=c.lr,
532
+ weight_decay=0)
533
+ else:
534
+ optimizer_st = None
535
+
536
+ # setup criterion
537
+ criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
538
+
539
+ if args.restore_path:
540
+ checkpoint = torch.load(args.restore_path, map_location='cpu')
541
+ try:
542
+ print(" > Restoring Model.")
543
+ model.load_state_dict(checkpoint['model'])
544
+ # optimizer restore
545
+ print(" > Restoring Optimizer.")
546
+ optimizer.load_state_dict(checkpoint['optimizer'])
547
+ if "scaler" in checkpoint and c.mixed_precision:
548
+ print(" > Restoring AMP Scaler...")
549
+ scaler.load_state_dict(checkpoint["scaler"])
550
+ if c.reinit_layers:
551
+ raise RuntimeError
552
+ except (KeyError, RuntimeError):
553
+ print(" > Partial model initialization.")
554
+ model_dict = model.state_dict()
555
+ model_dict = set_init_dict(model_dict, checkpoint['model'], c)
556
+ # torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
557
+ # print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
558
+ model.load_state_dict(model_dict)
559
+ del model_dict
560
+
561
+ for group in optimizer.param_groups:
562
+ group['lr'] = c.lr
563
+ print(" > Model restored from step %d" % checkpoint['step'],
564
+ flush=True)
565
+ args.restore_step = checkpoint['step']
566
+ else:
567
+ args.restore_step = 0
568
+
569
+ if use_cuda:
570
+ model.cuda()
571
+ criterion.cuda()
572
+
573
+ # DISTRIBUTED
574
+ if num_gpus > 1:
575
+ model = apply_gradient_allreduce(model)
576
+
577
+ if c.noam_schedule:
578
+ scheduler = NoamLR(optimizer,
579
+ warmup_steps=c.warmup_steps,
580
+ last_epoch=args.restore_step - 1)
581
+ else:
582
+ scheduler = None
583
+
584
+ num_params = count_parameters(model)
585
+ print("\n > Model has {} parameters".format(num_params), flush=True)
586
+
587
+ if 'best_loss' not in locals():
588
+ best_loss = float('inf')
589
+
590
+ # define data loaders
591
+ train_loader = setup_loader(ap,
592
+ model.decoder.r,
593
+ is_val=False,
594
+ verbose=True)
595
+ eval_loader = setup_loader(ap, model.decoder.r, is_val=True)
596
+
597
+ global_step = args.restore_step
598
+ for epoch in range(0, c.epochs):
599
+ c_logger.print_epoch_start(epoch, c.epochs)
600
+ # set gradual training
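+ # gradual training lowers the reduction factor r (and adapts the batch size) as global_step grows, following the (step, r, batch_size) entries in c.gradual_training.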
601
+ if c.gradual_training is not None:
602
+ r, c.batch_size = gradual_training_scheduler(global_step, c)
603
+ c.r = r
604
+ model.decoder.set_r(r)
605
+ if c.bidirectional_decoder:
606
+ model.decoder_backward.set_r(r)
607
+ train_loader.dataset.outputs_per_step = r
608
+ eval_loader.dataset.outputs_per_step = r
609
+ train_loader = setup_loader(ap,
610
+ model.decoder.r,
611
+ is_val=False,
612
+ dataset=train_loader.dataset)
613
+ eval_loader = setup_loader(ap,
614
+ model.decoder.r,
615
+ is_val=True,
616
+ dataset=eval_loader.dataset)
617
+ print("\n > Number of output frames:", model.decoder.r)
618
+ # train one epoch
619
+ train_avg_loss_dict, global_step = train(train_loader, model,
620
+ criterion, optimizer,
621
+ optimizer_st, scheduler, ap,
622
+ global_step, epoch, scaler,
623
+ scaler_st)
624
+ # eval one epoch
625
+ eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
626
+ global_step, epoch)
627
+ c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
628
+ target_loss = train_avg_loss_dict['avg_postnet_loss']
629
+ if c.run_eval:
630
+ target_loss = eval_avg_loss_dict['avg_postnet_loss']
631
+ best_loss = save_best_model(
632
+ target_loss,
633
+ best_loss,
634
+ model,
635
+ optimizer,
636
+ global_step,
637
+ epoch,
638
+ c.r,
639
+ OUT_PATH,
640
+ scaler=scaler.state_dict() if c.mixed_precision else None)
641
+
642
+
643
+ if __name__ == '__main__':
644
+ parser = argparse.ArgumentParser()
645
+ parser.add_argument(
646
+ '--continue_path',
647
+ type=str,
648
+ help='Training output folder to continue a previous training run. If set, "config_path" is ignored.',
649
+ default='',
650
+ required='--config_path' not in sys.argv)
651
+ parser.add_argument(
652
+ '--restore_path',
653
+ type=str,
654
+ help='Model file to be restored. Use to finetune a model.',
655
+ default='')
656
+ parser.add_argument(
657
+ '--config_path',
658
+ type=str,
659
+ help='Path to config file for training.',
660
+ required='--continue_path' not in sys.argv
661
+ )
662
+ parser.add_argument('--debug',
663
+ type=bool,
664
+ default=False,
665
+ help='Run in debug mode; skip the commit integrity check before training.')
666
+
667
+ # DISTRIBUTED
668
+ parser.add_argument(
669
+ '--rank',
670
+ type=int,
671
+ default=0,
672
+ help='DISTRIBUTED: process rank for distributed training.')
673
+ parser.add_argument('--group_id',
674
+ type=str,
675
+ default="",
676
+ help='DISTRIBUTED: process group id.')
677
+ args = parser.parse_args()
678
+
679
+ if args.continue_path != '':
680
+ print(f" > Training continues for {args.continue_path}")
681
+ args.output_path = args.continue_path
682
+ args.config_path = os.path.join(args.continue_path, 'config.json')
683
+ list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # pick up all checkpoint files in the output folder
684
+ latest_model_file = max(list_of_files, key=os.path.getctime)
685
+ args.restore_path = latest_model_file
686
+
687
+ # setup output paths and read configs
688
+ c = load_config(args.config_path)
689
+ check_config_tts(c)
690
+ _ = os.path.dirname(os.path.realpath(__file__))
691
+
692
+ if c.mixed_precision:
693
+ print(" > Mixed precision mode is ON")
694
+
695
+ OUT_PATH = args.continue_path
696
+ if args.continue_path == '':
697
+ OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
698
+
699
+ AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
700
+
701
+ c_logger = ConsoleLogger()
702
+
703
+ if args.rank == 0:
704
+ os.makedirs(AUDIO_PATH, exist_ok=True)
705
+ new_fields = {}
706
+ if args.restore_path:
707
+ new_fields["restore_path"] = args.restore_path
708
+ new_fields["github_branch"] = get_git_branch()
709
+ copy_model_files(c, args.config_path,
710
+ OUT_PATH, new_fields)
711
+ os.chmod(AUDIO_PATH, 0o775)
712
+ os.chmod(OUT_PATH, 0o775)
713
+
714
+ LOG_DIR = OUT_PATH
715
+ tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
716
+
717
+ # write model desc to tensorboard
718
+ tb_logger.tb_add_text('model-description', c['run_description'], 0)
719
+
720
+ try:
721
+ main(args)
722
+ except KeyboardInterrupt:
723
+ remove_experiment_folder(OUT_PATH)
724
+ try:
725
+ sys.exit(0)
726
+ except SystemExit:
727
+ os._exit(0) # pylint: disable=protected-access
728
+ except Exception: # pylint: disable=broad-except
729
+ remove_experiment_folder(OUT_PATH)
730
+ traceback.print_exc()
731
+ sys.exit(1)
TTS/bin/train_vocoder_gan.py ADDED
@@ -0,0 +1,664 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import os
4
+ import sys
5
+ import time
6
+ import traceback
7
+ from inspect import signature
8
+
9
+ import torch
10
+ from torch.utils.data import DataLoader
11
+ from TTS.utils.audio import AudioProcessor
12
+ from TTS.utils.console_logger import ConsoleLogger
13
+ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
14
+ create_experiment_folder, get_git_branch,
15
+ remove_experiment_folder, set_init_dict)
16
+ from TTS.utils.io import copy_model_files, load_config
17
+ from TTS.utils.radam import RAdam
18
+ from TTS.utils.tensorboard_logger import TensorboardLogger
19
+ from TTS.utils.training import setup_torch_training_env
20
+ from TTS.vocoder.datasets.gan_dataset import GANDataset
21
+ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
22
+ from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
23
+ from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator,
24
+ setup_generator)
25
+ from TTS.vocoder.utils.io import save_best_model, save_checkpoint
26
+
27
+ # DISTRIBUTED
28
+ from torch.nn.parallel import DistributedDataParallel as DDP_th
29
+ from torch.utils.data.distributed import DistributedSampler
30
+ from TTS.utils.distribute import init_distributed
31
+
32
+ use_cuda, num_gpus = setup_torch_training_env(True, True)
33
+
34
+
35
+ def setup_loader(ap, is_val=False, verbose=False):
36
+ if is_val and not c.run_eval:
37
+ loader = None
38
+ else:
39
+ dataset = GANDataset(ap=ap,
40
+ items=eval_data if is_val else train_data,
41
+ seq_len=c.seq_len,
42
+ hop_len=ap.hop_length,
43
+ pad_short=c.pad_short,
44
+ conv_pad=c.conv_pad,
45
+ is_training=not is_val,
46
+ return_segments=not is_val,
47
+ use_noise_augment=c.use_noise_augment,
48
+ use_cache=c.use_cache,
49
+ verbose=verbose)
50
+ dataset.shuffle_mapping()
51
+ sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
52
+ loader = DataLoader(dataset,
53
+ batch_size=1 if is_val else c.batch_size,
54
+ shuffle=False if num_gpus > 1 else True,
55
+ drop_last=False,
56
+ sampler=sampler,
57
+ num_workers=c.num_val_loader_workers
58
+ if is_val else c.num_loader_workers,
59
+ pin_memory=False)
60
+ return loader
61
+
62
+
63
+ def format_data(data):
64
+ if isinstance(data[0], list):
65
+ # setup input data
66
+ c_G, x_G = data[0]
67
+ c_D, x_D = data[1]
68
+
69
+ # dispatch data to GPU
70
+ if use_cuda:
71
+ c_G = c_G.cuda(non_blocking=True)
72
+ x_G = x_G.cuda(non_blocking=True)
73
+ c_D = c_D.cuda(non_blocking=True)
74
+ x_D = x_D.cuda(non_blocking=True)
75
+
76
+ return c_G, x_G, c_D, x_D
77
+
78
+ # return a whole audio segment
79
+ co, x = data
80
+ if use_cuda:
81
+ co = co.cuda(non_blocking=True)
82
+ x = x.cuda(non_blocking=True)
83
+ return co, x, None, None
84
+
85
+
86
+ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
87
+ scheduler_G, scheduler_D, ap, global_step, epoch):
88
+ data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
89
+ model_G.train()
90
+ model_D.train()
91
+ epoch_time = 0
92
+ keep_avg = KeepAverage()
93
+ if use_cuda:
94
+ batch_n_iter = int(
95
+ len(data_loader.dataset) / (c.batch_size * num_gpus))
96
+ else:
97
+ batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
98
+ end_time = time.time()
99
+ c_logger.print_train_start()
100
+ for num_iter, data in enumerate(data_loader):
101
+ start_time = time.time()
102
+
103
+ # format data
104
+ c_G, y_G, c_D, y_D = format_data(data)
105
+ loader_time = time.time() - end_time
106
+
107
+ global_step += 1
108
+
109
+ ##############################
110
+ # GENERATOR
111
+ ##############################
112
+
113
+ # generator pass
114
+ y_hat = model_G(c_G)
115
+ y_hat_sub = None
116
+ y_G_sub = None
117
+ y_hat_vis = y_hat # for visualization
118
+
119
+ # PQMF formatting
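+ # a multi-band generator (e.g. multi-band MelGAN) emits one channel per subband; pqmf_synthesis folds the subbands back into a single full-band waveform.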
120
+ if y_hat.shape[1] > 1:
121
+ y_hat_sub = y_hat
122
+ y_hat = model_G.pqmf_synthesis(y_hat)
123
+ y_hat_vis = y_hat
124
+ y_G_sub = model_G.pqmf_analysis(y_G)
125
+
126
+ scores_fake, feats_fake, feats_real = None, None, None
127
+ if global_step > c.steps_to_start_discriminator:
128
+
129
+ # run D with or without cond. features
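+ # a 2-parameter forward signature marks a conditional discriminator (waveform plus mel features); otherwise it only sees the waveform.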
130
+ if len(signature(model_D.forward).parameters) == 2:
131
+ D_out_fake = model_D(y_hat, c_G)
132
+ else:
133
+ D_out_fake = model_D(y_hat)
134
+ D_out_real = None
135
+
136
+ if c.use_feat_match_loss:
137
+ with torch.no_grad():
138
+ D_out_real = model_D(y_G)
139
+
140
+ # format D outputs
141
+ if isinstance(D_out_fake, tuple):
142
+ scores_fake, feats_fake = D_out_fake
143
+ if D_out_real is None:
144
+ feats_real = None
145
+ else:
146
+ _, feats_real = D_out_real
147
+ else:
148
+ scores_fake = D_out_fake
149
+
150
+ # compute losses
151
+ loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
152
+ feats_real, y_hat_sub, y_G_sub)
153
+ loss_G = loss_G_dict['G_loss']
154
+
155
+ # optimizer generator
156
+ optimizer_G.zero_grad()
157
+ loss_G.backward()
158
+ if c.gen_clip_grad > 0:
159
+ torch.nn.utils.clip_grad_norm_(model_G.parameters(),
160
+ c.gen_clip_grad)
161
+ optimizer_G.step()
162
+ if scheduler_G is not None:
163
+ scheduler_G.step()
164
+
165
+ loss_dict = dict()
166
+ for key, value in loss_G_dict.items():
167
+ if isinstance(value, (int, float)):
168
+ loss_dict[key] = value
169
+ else:
170
+ loss_dict[key] = value.item()
171
+
172
+ ##############################
173
+ # DISCRIMINATOR
174
+ ##############################
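+ # the discriminator only starts training after c.steps_to_start_discriminator generator warm-up steps.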
175
+ if global_step >= c.steps_to_start_discriminator:
176
+ # discriminator pass
177
+ with torch.no_grad():
178
+ y_hat = model_G(c_D)
179
+
180
+ # PQMF formatting
181
+ if y_hat.shape[1] > 1:
182
+ y_hat = model_G.pqmf_synthesis(y_hat)
183
+
184
+ # run D with or without cond. features
185
+ if len(signature(model_D.forward).parameters) == 2:
186
+ D_out_fake = model_D(y_hat.detach(), c_D)
187
+ D_out_real = model_D(y_D, c_D)
188
+ else:
189
+ D_out_fake = model_D(y_hat.detach())
190
+ D_out_real = model_D(y_D)
191
+
192
+ # format D outputs
193
+ if isinstance(D_out_fake, tuple):
194
+ scores_fake, feats_fake = D_out_fake
195
+ if D_out_real is None:
196
+ scores_real, feats_real = None, None
197
+ else:
198
+ scores_real, feats_real = D_out_real
199
+ else:
200
+ scores_fake = D_out_fake
201
+ scores_real = D_out_real
202
+
203
+ # compute losses
204
+ loss_D_dict = criterion_D(scores_fake, scores_real)
205
+ loss_D = loss_D_dict['D_loss']
206
+
207
+ # optimizer discriminator
208
+ optimizer_D.zero_grad()
209
+ loss_D.backward()
210
+ if c.disc_clip_grad > 0:
211
+ torch.nn.utils.clip_grad_norm_(model_D.parameters(),
212
+ c.disc_clip_grad)
213
+ optimizer_D.step()
214
+ if scheduler_D is not None:
215
+ scheduler_D.step()
216
+
217
+ for key, value in loss_D_dict.items():
218
+ if isinstance(value, (int, float)):
219
+ loss_dict[key] = value
220
+ else:
221
+ loss_dict[key] = value.item()
222
+
223
+ step_time = time.time() - start_time
224
+ epoch_time += step_time
225
+
226
+ # get current learning rates
227
+ current_lr_G = list(optimizer_G.param_groups)[0]['lr']
228
+ current_lr_D = list(optimizer_D.param_groups)[0]['lr']
229
+
230
+ # update avg stats
231
+ update_train_values = dict()
232
+ for key, value in loss_dict.items():
233
+ update_train_values['avg_' + key] = value
234
+ update_train_values['avg_loader_time'] = loader_time
235
+ update_train_values['avg_step_time'] = step_time
236
+ keep_avg.update_values(update_train_values)
237
+
238
+ # print training stats
239
+ if global_step % c.print_step == 0:
240
+ log_dict = {
241
+ 'step_time': [step_time, 2],
242
+ 'loader_time': [loader_time, 4],
243
+ "current_lr_G": current_lr_G,
244
+ "current_lr_D": current_lr_D
245
+ }
246
+ c_logger.print_train_step(batch_n_iter, num_iter, global_step,
247
+ log_dict, loss_dict, keep_avg.avg_values)
248
+
249
+ if args.rank == 0:
250
+ # plot step stats
251
+ if global_step % 10 == 0:
252
+ iter_stats = {
253
+ "lr_G": current_lr_G,
254
+ "lr_D": current_lr_D,
255
+ "step_time": step_time
256
+ }
257
+ iter_stats.update(loss_dict)
258
+ tb_logger.tb_train_iter_stats(global_step, iter_stats)
259
+
260
+ # save checkpoint
261
+ if global_step % c.save_step == 0:
262
+ if c.checkpoint:
263
+ # save model
264
+ save_checkpoint(model_G,
265
+ optimizer_G,
266
+ scheduler_G,
267
+ model_D,
268
+ optimizer_D,
269
+ scheduler_D,
270
+ global_step,
271
+ epoch,
272
+ OUT_PATH,
273
+ model_losses=loss_dict)
274
+
275
+ # compute spectrograms
276
+ figures = plot_results(y_hat_vis, y_G, ap, global_step,
277
+ 'train')
278
+ tb_logger.tb_train_figures(global_step, figures)
279
+
280
+ # Sample audio
281
+ sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
282
+ tb_logger.tb_train_audios(global_step,
283
+ {'train/audio': sample_voice},
284
+ c.audio["sample_rate"])
285
+ end_time = time.time()
286
+
287
+ # print epoch stats
288
+ c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
289
+
290
+ # Plot Training Epoch Stats
291
+ epoch_stats = {"epoch_time": epoch_time}
292
+ epoch_stats.update(keep_avg.avg_values)
293
+ if args.rank == 0:
294
+ tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
295
+ # TODO: plot model stats
296
+ # if c.tb_model_param_stats:
297
+ # tb_logger.tb_model_weights(model, global_step)
298
+ return keep_avg.avg_values, global_step
299
+
300
+
301
+ @torch.no_grad()
302
+ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch):
303
+ data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
304
+ model_G.eval()
305
+ model_D.eval()
306
+ epoch_time = 0
307
+ keep_avg = KeepAverage()
308
+ end_time = time.time()
309
+ c_logger.print_eval_start()
310
+ for num_iter, data in enumerate(data_loader):
311
+ start_time = time.time()
312
+
313
+ # format data
314
+ c_G, y_G, _, _ = format_data(data)
315
+ loader_time = time.time() - end_time
316
+
317
+ global_step += 1
318
+
319
+ ##############################
320
+ # GENERATOR
321
+ ##############################
322
+
323
+ # generator pass
324
+ y_hat = model_G(c_G)
325
+ y_hat_sub = None
326
+ y_G_sub = None
327
+
328
+ # PQMF formatting
329
+ if y_hat.shape[1] > 1:
330
+ y_hat_sub = y_hat
331
+ y_hat = model_G.pqmf_synthesis(y_hat)
332
+ y_G_sub = model_G.pqmf_analysis(y_G)
333
+
334
+ scores_fake, feats_fake, feats_real = None, None, None
335
+ if global_step > c.steps_to_start_discriminator:
336
+
337
+ if len(signature(model_D.forward).parameters) == 2:
338
+ D_out_fake = model_D(y_hat, c_G)
339
+ else:
340
+ D_out_fake = model_D(y_hat)
341
+ D_out_real = None
342
+
343
+ if c.use_feat_match_loss:
344
+ with torch.no_grad():
345
+ D_out_real = model_D(y_G)
346
+
347
+ # format D outputs
348
+ if isinstance(D_out_fake, tuple):
349
+ scores_fake, feats_fake = D_out_fake
350
+ if D_out_real is None:
351
+ feats_real = None
352
+ else:
353
+ _, feats_real = D_out_real
354
+ else:
355
+ scores_fake = D_out_fake
356
+ feats_fake, feats_real = None, None
357
+
358
+ # compute losses
359
+ loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
360
+ feats_real, y_hat_sub, y_G_sub)
361
+
362
+ loss_dict = dict()
363
+ for key, value in loss_G_dict.items():
364
+ if isinstance(value, (int, float)):
365
+ loss_dict[key] = value
366
+ else:
367
+ loss_dict[key] = value.item()
368
+
369
+ ##############################
370
+ # DISCRIMINATOR
371
+ ##############################
372
+
373
+ if global_step >= c.steps_to_start_discriminator:
374
+ # discriminator pass
375
+ with torch.no_grad():
376
+ y_hat = model_G(c_G)
377
+
378
+ # PQMF formatting
379
+ if y_hat.shape[1] > 1:
380
+ y_hat = model_G.pqmf_synthesis(y_hat)
381
+
382
+ # run D with or without cond. features
383
+ if len(signature(model_D.forward).parameters) == 2:
384
+ D_out_fake = model_D(y_hat.detach(), c_G)
385
+ D_out_real = model_D(y_G, c_G)
386
+ else:
387
+ D_out_fake = model_D(y_hat.detach())
388
+ D_out_real = model_D(y_G)
389
+
390
+ # format D outputs
391
+ if isinstance(D_out_fake, tuple):
392
+ scores_fake, feats_fake = D_out_fake
393
+ if D_out_real is None:
394
+ scores_real, feats_real = None, None
395
+ else:
396
+ scores_real, feats_real = D_out_real
397
+ else:
398
+ scores_fake = D_out_fake
399
+ scores_real = D_out_real
400
+
401
+ # compute losses
402
+ loss_D_dict = criterion_D(scores_fake, scores_real)
403
+
404
+ for key, value in loss_D_dict.items():
405
+ if isinstance(value, (int, float)):
406
+ loss_dict[key] = value
407
+ else:
408
+ loss_dict[key] = value.item()
409
+
410
+ step_time = time.time() - start_time
411
+ epoch_time += step_time
412
+
413
+ # update avg stats
414
+ update_eval_values = dict()
415
+ for key, value in loss_dict.items():
416
+ update_eval_values['avg_' + key] = value
417
+ update_eval_values['avg_loader_time'] = loader_time
418
+ update_eval_values['avg_step_time'] = step_time
419
+ keep_avg.update_values(update_eval_values)
420
+
421
+ # print eval stats
422
+ if c.print_eval:
423
+ c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
424
+
425
+ if args.rank == 0:
426
+ # compute spectrograms
427
+ figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
428
+ tb_logger.tb_eval_figures(global_step, figures)
429
+
430
+ # Sample audio
431
+ sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
432
+ tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
433
+ c.audio["sample_rate"])
434
+
435
+ tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
436
+
437
+ # synthesize a full voice
438
+ data_loader.return_segments = False
439
+
440
+ return keep_avg.avg_values
441
+
442
+
443
+ # FIXME: move args definition/parsing inside of main?
444
+ def main(args): # pylint: disable=redefined-outer-name
445
+ # pylint: disable=global-variable-undefined
446
+ global train_data, eval_data
447
+ print(f" > Loading wavs from: {c.data_path}")
448
+ if c.feature_path is not None:
449
+ print(f" > Loading features from: {c.feature_path}")
450
+ eval_data, train_data = load_wav_feat_data(
451
+ c.data_path, c.feature_path, c.eval_split_size)
452
+ else:
453
+ eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
454
+
455
+ # setup audio processor
456
+ ap = AudioProcessor(**c.audio)
457
+
458
+ # DISTRIBUTED
459
+ if num_gpus > 1:
460
+ init_distributed(args.rank, num_gpus, args.group_id,
461
+ c.distributed["backend"], c.distributed["url"])
462
+
463
+ # setup models
464
+ model_gen = setup_generator(c)
465
+ model_disc = setup_discriminator(c)
466
+
467
+ # setup optimizers
468
+ optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
469
+ optimizer_disc = RAdam(model_disc.parameters(),
470
+ lr=c.lr_disc,
471
+ weight_decay=0)
472
+
473
+ # schedulers
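+ # scheduler classes are resolved by name from torch.optim.lr_scheduler and instantiated with the matching *_params dict from the config.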
474
+ scheduler_gen = None
475
+ scheduler_disc = None
476
+ if 'lr_scheduler_gen' in c:
477
+ scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
478
+ scheduler_gen = scheduler_gen(
479
+ optimizer_gen, **c.lr_scheduler_gen_params)
480
+ if 'lr_scheduler_disc' in c:
481
+ scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
482
+ scheduler_disc = scheduler_disc(
483
+ optimizer_disc, **c.lr_scheduler_disc_params)
484
+
485
+ # setup criterion
486
+ criterion_gen = GeneratorLoss(c)
487
+ criterion_disc = DiscriminatorLoss(c)
488
+
489
+ if args.restore_path:
490
+ checkpoint = torch.load(args.restore_path, map_location='cpu')
491
+ try:
492
+ print(" > Restoring Generator Model...")
493
+ model_gen.load_state_dict(checkpoint['model'])
494
+ print(" > Restoring Generator Optimizer...")
495
+ optimizer_gen.load_state_dict(checkpoint['optimizer'])
496
+ print(" > Restoring Discriminator Model...")
497
+ model_disc.load_state_dict(checkpoint['model_disc'])
498
+ print(" > Restoring Discriminator Optimizer...")
499
+ optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
500
+ if 'scheduler' in checkpoint:
501
+ print(" > Restoring Generator LR Scheduler...")
502
+ scheduler_gen.load_state_dict(checkpoint['scheduler'])
503
+ # NOTE: Not sure if necessary
504
+ scheduler_gen.optimizer = optimizer_gen
505
+ if 'scheduler_disc' in checkpoint:
506
+ print(" > Restoring Discriminator LR Scheduler...")
507
+ scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
508
+ scheduler_disc.optimizer = optimizer_disc
509
+ except RuntimeError:
510
+ # restore only matching layers.
511
+ print(" > Partial model initialization...")
512
+ model_dict = model_gen.state_dict()
513
+ model_dict = set_init_dict(model_dict, checkpoint['model'], c)
514
+ model_gen.load_state_dict(model_dict)
515
+
516
+ model_dict = model_disc.state_dict()
517
+ model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
518
+ model_disc.load_state_dict(model_dict)
519
+ del model_dict
520
+
521
+ # reset lr if not continuing a previous training run.
522
+ for group in optimizer_gen.param_groups:
523
+ group['lr'] = c.lr_gen
524
+
525
+ for group in optimizer_disc.param_groups:
526
+ group['lr'] = c.lr_disc
527
+
528
+ print(" > Model restored from step %d" % checkpoint['step'],
529
+ flush=True)
530
+ args.restore_step = checkpoint['step']
531
+ else:
532
+ args.restore_step = 0
533
+
534
+ if use_cuda:
535
+ model_gen.cuda()
536
+ criterion_gen.cuda()
537
+ model_disc.cuda()
538
+ criterion_disc.cuda()
539
+
540
+ # DISTRIBUTED
541
+ if num_gpus > 1:
542
+ model_gen = DDP_th(model_gen, device_ids=[args.rank])
543
+ model_disc = DDP_th(model_disc, device_ids=[args.rank])
544
+
545
+ num_params = count_parameters(model_gen)
546
+ print(" > Generator has {} parameters".format(num_params), flush=True)
547
+ num_params = count_parameters(model_disc)
548
+ print(" > Discriminator has {} parameters".format(num_params), flush=True)
549
+
550
+ if 'best_loss' not in locals():
551
+ best_loss = float('inf')
552
+
553
+ global_step = args.restore_step
554
+ for epoch in range(0, c.epochs):
555
+ c_logger.print_epoch_start(epoch, c.epochs)
556
+ _, global_step = train(model_gen, criterion_gen, optimizer_gen,
557
+ model_disc, criterion_disc, optimizer_disc,
558
+ scheduler_gen, scheduler_disc, ap, global_step,
559
+ epoch)
560
+ eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap,
561
+ global_step, epoch)
562
+ c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
563
+ target_loss = eval_avg_loss_dict[c.target_loss]
564
+ best_loss = save_best_model(target_loss,
565
+ best_loss,
566
+ model_gen,
567
+ optimizer_gen,
568
+ scheduler_gen,
569
+ model_disc,
570
+ optimizer_disc,
571
+ scheduler_disc,
572
+ global_step,
573
+ epoch,
574
+ OUT_PATH,
575
+ model_losses=eval_avg_loss_dict)
576
+
577
+
578
+ if __name__ == '__main__':
579
+ parser = argparse.ArgumentParser()
580
+ parser.add_argument(
581
+ '--continue_path',
582
+ type=str,
583
+ help='Training output folder to continue a previous training run. If set, "config_path" is ignored.',
584
+ default='',
585
+ required='--config_path' not in sys.argv)
586
+ parser.add_argument(
587
+ '--restore_path',
588
+ type=str,
589
+ help='Model file to be restored. Use to finetune a model.',
590
+ default='')
591
+ parser.add_argument('--config_path',
592
+ type=str,
593
+ help='Path to config file for training.',
594
+ required='--continue_path' not in sys.argv)
595
+ parser.add_argument('--debug',
596
+ type=bool,
597
+ default=False,
598
+ help='Run in debug mode; skip the commit integrity check before training.')
599
+
600
+ # DISTRIBUTED
601
+ parser.add_argument(
602
+ '--rank',
603
+ type=int,
604
+ default=0,
605
+ help='DISTRIBUTED: process rank for distributed training.')
606
+ parser.add_argument('--group_id',
607
+ type=str,
608
+ default="",
609
+ help='DISTRIBUTED: process group id.')
610
+ args = parser.parse_args()
611
+
612
+ if args.continue_path != '':
613
+ args.output_path = args.continue_path
614
+ args.config_path = os.path.join(args.continue_path, 'config.json')
615
+ list_of_files = glob.glob(
616
+ args.continue_path +
617
+ "/*.pth.tar") # * means all if need specific format then *.csv
618
+ latest_model_file = max(list_of_files, key=os.path.getctime)
619
+ args.restore_path = latest_model_file
620
+ print(f" > Training continues for {args.restore_path}")
621
+
622
+ # setup output paths and read configs
623
+ c = load_config(args.config_path)
624
+ # check_config(c)
625
+ _ = os.path.dirname(os.path.realpath(__file__))
626
+
627
+ OUT_PATH = args.continue_path
628
+ if args.continue_path == '':
629
+ OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
630
+ args.debug)
631
+
632
+ AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
633
+
634
+ c_logger = ConsoleLogger()
635
+
636
+ if args.rank == 0:
637
+ os.makedirs(AUDIO_PATH, exist_ok=True)
638
+ new_fields = {}
639
+ if args.restore_path:
640
+ new_fields["restore_path"] = args.restore_path
641
+ new_fields["github_branch"] = get_git_branch()
642
+ copy_model_files(c, args.config_path,
643
+ OUT_PATH, new_fields)
644
+ os.chmod(AUDIO_PATH, 0o775)
645
+ os.chmod(OUT_PATH, 0o775)
646
+
647
+ LOG_DIR = OUT_PATH
648
+ tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
649
+
650
+ # write model desc to tensorboard
651
+ tb_logger.tb_add_text('model-description', c['run_description'], 0)
652
+
653
+ try:
654
+ main(args)
655
+ except KeyboardInterrupt:
656
+ remove_experiment_folder(OUT_PATH)
657
+ try:
658
+ sys.exit(0)
659
+ except SystemExit:
660
+ os._exit(0) # pylint: disable=protected-access
661
+ except Exception: # pylint: disable=broad-except
662
+ remove_experiment_folder(OUT_PATH)
663
+ traceback.print_exc()
664
+ sys.exit(1)
TTS/bin/train_vocoder_wavegrad.py ADDED
@@ -0,0 +1,511 @@
1
+ import argparse
2
+ import glob
3
+ import os
4
+ import sys
5
+ import time
6
+ import traceback
7
+ import numpy as np
8
+
9
+ import torch
10
+ # DISTRIBUTED
11
+ from torch.nn.parallel import DistributedDataParallel as DDP_th
12
+ from torch.optim import Adam
13
+ from torch.utils.data import DataLoader
14
+ from torch.utils.data.distributed import DistributedSampler
15
+ from TTS.utils.audio import AudioProcessor
16
+ from TTS.utils.console_logger import ConsoleLogger
17
+ from TTS.utils.distribute import init_distributed
18
+ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
19
+ create_experiment_folder, get_git_branch,
20
+ remove_experiment_folder, set_init_dict)
21
+ from TTS.utils.io import copy_model_files, load_config
22
+ from TTS.utils.tensorboard_logger import TensorboardLogger
23
+ from TTS.utils.training import setup_torch_training_env
24
+ from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
25
+ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
26
+ from TTS.vocoder.utils.generic_utils import plot_results, setup_generator
27
+ from TTS.vocoder.utils.io import save_best_model, save_checkpoint
28
+
29
+ use_cuda, num_gpus = setup_torch_training_env(True, True)
30
+
31
+
32
+ def setup_loader(ap, is_val=False, verbose=False):
33
+ if is_val and not c.run_eval:
34
+ loader = None
35
+ else:
36
+ dataset = WaveGradDataset(ap=ap,
37
+ items=eval_data if is_val else train_data,
38
+ seq_len=c.seq_len,
39
+ hop_len=ap.hop_length,
40
+ pad_short=c.pad_short,
41
+ conv_pad=c.conv_pad,
42
+ is_training=not is_val,
43
+ return_segments=True,
44
+ use_noise_augment=False,
45
+ use_cache=c.use_cache,
46
+ verbose=verbose)
47
+ sampler = DistributedSampler(dataset) if num_gpus > 1 else None
48
+ loader = DataLoader(dataset,
49
+ batch_size=c.batch_size,
50
+ shuffle=num_gpus <= 1,
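+ # DataLoader-level shuffling must be disabled when a DistributedSampler is used; the sampler handles shuffling itself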
51
+ drop_last=False,
52
+ sampler=sampler,
53
+ num_workers=c.num_val_loader_workers
54
+ if is_val else c.num_loader_workers,
55
+ pin_memory=False)
56
+
57
+
58
+ return loader
59
+
60
+
61
+ def format_data(data):
62
+ # return a whole audio segment
63
+ m, x = data
64
+ x = x.unsqueeze(1)
65
+ if use_cuda:
66
+ m = m.cuda(non_blocking=True)
67
+ x = x.cuda(non_blocking=True)
68
+ return m, x
69
+
70
+
71
+ def format_test_data(data):
72
+ # return a whole audio segment
73
+ m, x = data
74
+ m = m[None, ...]
75
+ x = x[None, None, ...]
76
+ if use_cuda:
77
+ m = m.cuda(non_blocking=True)
78
+ x = x.cuda(non_blocking=True)
79
+ return m, x
80
+
81
+
82
+ def train(model, criterion, optimizer,
83
+ scheduler, scaler, ap, global_step, epoch):
84
+ data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
85
+ model.train()
86
+ epoch_time = 0
87
+ keep_avg = KeepAverage()
88
+ if use_cuda:
89
+ batch_n_iter = int(
90
+ len(data_loader.dataset) / (c.batch_size * num_gpus))
91
+ else:
92
+ batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
93
+ end_time = time.time()
94
+ c_logger.print_train_start()
95
+ # setup noise schedule
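+ # WaveGrad trains against a continuous noise level: precomputing the beta
+ # schedule on the model lets compute_y_n() sample a noise scale and build
+ # noisy targets for each batch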
96
+ noise_schedule = c['train_noise_schedule']
97
+ betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
98
+ if hasattr(model, 'module'):
99
+ model.module.compute_noise_level(betas)
100
+ else:
101
+ model.compute_noise_level(betas)
102
+ for num_iter, data in enumerate(data_loader):
103
+ start_time = time.time()
104
+
105
+ # format data
106
+ m, x = format_data(data)
107
+ loader_time = time.time() - end_time
108
+
109
+ global_step += 1
110
+
111
+ with torch.cuda.amp.autocast(enabled=c.mixed_precision):
112
+ # compute noisy input
113
+ if hasattr(model, 'module'):
114
+ noise, x_noisy, noise_scale = model.module.compute_y_n(x)
115
+ else:
116
+ noise, x_noisy, noise_scale = model.compute_y_n(x)
117
+
118
+ # forward pass
119
+ noise_hat = model(x_noisy, m, noise_scale)
120
+
121
+ # compute losses
122
+ loss = criterion(noise, noise_hat)
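+ # the diffusion objective: distance between the injected noise and the
+ # model's estimate of it (L1 here, set up in main())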
123
+ loss_wavegrad_dict = {'wavegrad_loss':loss}
124
+
125
+ # check nan loss
126
+ if torch.isnan(loss).any():
127
+ raise RuntimeError(f'Detected NaN loss at step {global_step}.')
128
+
129
+ optimizer.zero_grad()
130
+
131
+ # backward pass with loss scaling
132
+ if c.mixed_precision:
133
+ scaler.scale(loss).backward()
134
+ scaler.unscale_(optimizer)
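+ # unscale before clipping so the threshold applies to true gradient magnitudes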
135
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
136
+ c.clip_grad)
137
+ scaler.step(optimizer)
138
+ scaler.update()
139
+ else:
140
+ loss.backward()
141
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
142
+ c.clip_grad)
143
+ optimizer.step()
144
+
145
+ # schedule update
146
+ if scheduler is not None:
147
+ scheduler.step()
148
+
149
+ # disconnect loss values
150
+ loss_dict = dict()
151
+ for key, value in loss_wavegrad_dict.items():
152
+ if isinstance(value, int):
153
+ loss_dict[key] = value
154
+ else:
155
+ loss_dict[key] = value.item()
156
+
157
+ # epoch/step timing
158
+ step_time = time.time() - start_time
159
+ epoch_time += step_time
160
+
161
+ # get current learning rates
162
+ current_lr = list(optimizer.param_groups)[0]['lr']
163
+
164
+ # update avg stats
165
+ update_train_values = dict()
166
+ for key, value in loss_dict.items():
167
+ update_train_values['avg_' + key] = value
168
+ update_train_values['avg_loader_time'] = loader_time
169
+ update_train_values['avg_step_time'] = step_time
170
+ keep_avg.update_values(update_train_values)
171
+
172
+ # print training stats
173
+ if global_step % c.print_step == 0:
174
+ log_dict = {
175
+ 'step_time': [step_time, 2],
176
+ 'loader_time': [loader_time, 4],
177
+ "current_lr": current_lr,
178
+ "grad_norm": grad_norm.item()
179
+ }
180
+ c_logger.print_train_step(batch_n_iter, num_iter, global_step,
181
+ log_dict, loss_dict, keep_avg.avg_values)
182
+
183
+ if args.rank == 0:
184
+ # plot step stats
185
+ if global_step % 10 == 0:
186
+ iter_stats = {
187
+ "lr": current_lr,
188
+ "grad_norm": grad_norm.item(),
189
+ "step_time": step_time
190
+ }
191
+ iter_stats.update(loss_dict)
192
+ tb_logger.tb_train_iter_stats(global_step, iter_stats)
193
+
194
+ # save checkpoint
195
+ if global_step % c.save_step == 0:
196
+ if c.checkpoint:
197
+ # save model
198
+ save_checkpoint(model,
199
+ optimizer,
200
+ scheduler,
201
+ None,
202
+ None,
203
+ None,
204
+ global_step,
205
+ epoch,
206
+ OUT_PATH,
207
+ model_losses=loss_dict,
208
+ scaler=scaler.state_dict() if c.mixed_precision else None)
209
+
210
+ end_time = time.time()
211
+
212
+ # print epoch stats
213
+ c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
214
+
215
+ # Plot Training Epoch Stats
216
+ epoch_stats = {"epoch_time": epoch_time}
217
+ epoch_stats.update(keep_avg.avg_values)
218
+ if args.rank == 0:
219
+ tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
220
+ # TODO: plot model stats
221
+ if c.tb_model_param_stats and args.rank == 0:
222
+ tb_logger.tb_model_weights(model, global_step)
223
+ return keep_avg.avg_values, global_step
224
+
225
+
226
+ @torch.no_grad()
227
+ def evaluate(model, criterion, ap, global_step, epoch):
228
+ data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
229
+ model.eval()
230
+ epoch_time = 0
231
+ keep_avg = KeepAverage()
232
+ end_time = time.time()
233
+ c_logger.print_eval_start()
234
+ for num_iter, data in enumerate(data_loader):
235
+ start_time = time.time()
236
+
237
+ # format data
238
+ m, x = format_data(data)
239
+ loader_time = time.time() - end_time
240
+
241
+ global_step += 1
242
+
243
+ # compute noisy input
244
+ if hasattr(model, 'module'):
245
+ noise, x_noisy, noise_scale = model.module.compute_y_n(x)
246
+ else:
247
+ noise, x_noisy, noise_scale = model.compute_y_n(x)
248
+
249
+
250
+ # forward pass
251
+ noise_hat = model(x_noisy, m, noise_scale)
252
+
253
+ # compute losses
254
+ loss = criterion(noise, noise_hat)
255
+ loss_wavegrad_dict = {'wavegrad_loss':loss}
256
+
257
+
258
+ loss_dict = dict()
259
+ for key, value in loss_wavegrad_dict.items():
260
+ if isinstance(value, (int, float)):
261
+ loss_dict[key] = value
262
+ else:
263
+ loss_dict[key] = value.item()
264
+
265
+ step_time = time.time() - start_time
266
+ epoch_time += step_time
267
+
268
+ # update avg stats
269
+ update_eval_values = dict()
270
+ for key, value in loss_dict.items():
271
+ update_eval_values['avg_' + key] = value
272
+ update_eval_values['avg_loader_time'] = loader_time
273
+ update_eval_values['avg_step_time'] = step_time
274
+ keep_avg.update_values(update_eval_values)
275
+
276
+ # print eval stats
277
+ if c.print_eval:
278
+ c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
279
+
280
+ if args.rank == 0:
281
+ data_loader.dataset.return_segments = False
282
+ samples = data_loader.dataset.load_test_samples(1)
283
+ m, x = format_test_data(samples[0])
284
+
285
+ # setup noise schedule and inference
286
+ noise_schedule = c['test_noise_schedule']
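+ # inference can use a much shorter schedule than training (typically tens of
+ # steps rather than ~1000) to trade a little quality for speed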
287
+ betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
288
+ if hasattr(model, 'module'):
289
+ model.module.compute_noise_level(betas)
290
+ # compute voice
291
+ x_pred = model.module.inference(m)
292
+ else:
293
+ model.compute_noise_level(betas)
294
+ # compute voice
295
+ x_pred = model.inference(m)
296
+
297
+ # compute spectrograms
298
+ figures = plot_results(x_pred, x, ap, global_step, 'eval')
299
+ tb_logger.tb_eval_figures(global_step, figures)
300
+
301
+ # Sample audio
302
+ sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
303
+ tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
304
+ c.audio["sample_rate"])
305
+
306
+ tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
307
+ data_loader.dataset.return_segments = True
308
+
309
+ return keep_avg.avg_values
310
+
311
+
312
+ def main(args): # pylint: disable=redefined-outer-name
313
+ # pylint: disable=global-variable-undefined
314
+ global train_data, eval_data
315
+ print(f" > Loading wavs from: {c.data_path}")
316
+ if c.feature_path is not None:
317
+ print(f" > Loading features from: {c.feature_path}")
318
+ eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
319
+ else:
320
+ eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
321
+
322
+ # setup audio processor
323
+ ap = AudioProcessor(**c.audio)
324
+
325
+ # DISTRIBUTED
326
+ if num_gpus > 1:
327
+ init_distributed(args.rank, num_gpus, args.group_id,
328
+ c.distributed["backend"], c.distributed["url"])
329
+
330
+ # setup models
331
+ model = setup_generator(c)
332
+
333
+ # scaler for mixed_precision
334
+ scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
335
+
336
+ # setup optimizers
337
+ optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)
338
+
339
+ # schedulers
340
+ scheduler = None
341
+ if 'lr_scheduler' in c:
342
+ scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
343
+ scheduler = scheduler(optimizer, **c.lr_scheduler_params)
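+ # e.g. a config could set "lr_scheduler": "MultiStepLR" with
+ # "lr_scheduler_params": {"milestones": [100000, 200000], "gamma": 0.5}
+ # (illustrative values; any class name from torch.optim.lr_scheduler works here)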
344
+
345
+ # setup criterion
346
+ criterion = torch.nn.L1Loss().cuda()
347
+
348
+ if args.restore_path:
349
+ checkpoint = torch.load(args.restore_path, map_location='cpu')
350
+ try:
351
+ print(" > Restoring Model...")
352
+ model.load_state_dict(checkpoint['model'])
353
+ print(" > Restoring Optimizer...")
354
+ optimizer.load_state_dict(checkpoint['optimizer'])
355
+ if 'scheduler' in checkpoint:
356
+ print(" > Restoring LR Scheduler...")
357
+ scheduler.load_state_dict(checkpoint['scheduler'])
358
+ # NOTE: Not sure if necessary
359
+ scheduler.optimizer = optimizer
360
+ if "scaler" in checkpoint and c.mixed_precision:
361
+ print(" > Restoring AMP Scaler...")
362
+ scaler.load_state_dict(checkpoint["scaler"])
363
+ except RuntimeError:
364
+ # restore only matching layers.
365
+ print(" > Partial model initialization...")
366
+ model_dict = model.state_dict()
367
+ model_dict = set_init_dict(model_dict, checkpoint['model'], c)
368
+ model.load_state_dict(model_dict)
369
+ del model_dict
370
+
371
+ # reset lr if not continuing training.
372
+ for group in optimizer.param_groups:
373
+ group['lr'] = c.lr
374
+
375
+ print(" > Model restored from step %d" % checkpoint['step'],
376
+ flush=True)
377
+ args.restore_step = checkpoint['step']
378
+ else:
379
+ args.restore_step = 0
380
+
381
+ if use_cuda:
382
+ model.cuda()
383
+ criterion.cuda()
384
+
385
+ # DISTRIBUTED
386
+ if num_gpus > 1:
387
+ model = DDP_th(model, device_ids=[args.rank])
388
+
389
+ num_params = count_parameters(model)
390
+ print(" > WaveGrad has {} parameters".format(num_params), flush=True)
391
+
392
+ if 'best_loss' not in locals():
393
+ best_loss = float('inf')
394
+
395
+ global_step = args.restore_step
396
+ for epoch in range(0, c.epochs):
397
+ c_logger.print_epoch_start(epoch, c.epochs)
398
+ _, global_step = train(model, criterion, optimizer,
399
+ scheduler, scaler, ap, global_step,
400
+ epoch)
401
+ eval_avg_loss_dict = evaluate(model, criterion, ap,
402
+ global_step, epoch)
403
+ c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
404
+ target_loss = eval_avg_loss_dict[c.target_loss]
405
+ best_loss = save_best_model(target_loss,
406
+ best_loss,
407
+ model,
408
+ optimizer,
409
+ scheduler,
410
+ None,
411
+ None,
412
+ None,
413
+ global_step,
414
+ epoch,
415
+ OUT_PATH,
416
+ model_losses=eval_avg_loss_dict,
417
+ scaler=scaler.state_dict() if c.mixed_precision else None)
418
+
419
+
420
+ if __name__ == '__main__':
421
+ parser = argparse.ArgumentParser()
422
+ parser.add_argument(
423
+ '--continue_path',
424
+ type=str,
425
+ help=
426
+ 'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
427
+ default='',
428
+ required='--config_path' not in sys.argv)
429
+ parser.add_argument(
430
+ '--restore_path',
431
+ type=str,
432
+ help='Model file to be restored. Use to finetune a model.',
433
+ default='')
434
+ parser.add_argument('--config_path',
435
+ type=str,
436
+ help='Path to config file for training.',
437
+ required='--continue_path' not in sys.argv)
438
+ parser.add_argument('--debug',
439
+ type=bool,
440
+ default=False,
441
+ help='Do not verify commit integrity to run training.')
442
+
443
+ # DISTRIBUTED
444
+ parser.add_argument(
445
+ '--rank',
446
+ type=int,
447
+ default=0,
448
+ help='DISTRIBUTED: process rank for distributed training.')
449
+ parser.add_argument('--group_id',
450
+ type=str,
451
+ default="",
452
+ help='DISTRIBUTED: process group id.')
453
+ args = parser.parse_args()
454
+
455
+ if args.continue_path != '':
456
+ args.output_path = args.continue_path
457
+ args.config_path = os.path.join(args.continue_path, 'config.json')
458
+ list_of_files = glob.glob(
459
+ args.continue_path +
460
+ "/*.pth.tar") # * means all if need specific format then *.csv
461
+ latest_model_file = max(list_of_files, key=os.path.getctime)
462
+ args.restore_path = latest_model_file
463
+ print(f" > Training continues for {args.restore_path}")
464
+
465
+ # setup output paths and read configs
466
+ c = load_config(args.config_path)
467
+ # check_config(c)
468
+ _ = os.path.dirname(os.path.realpath(__file__))
469
+
470
+ # MIXED PRECISION
471
+ if c.mixed_precision:
472
+ print(" > Mixed precision is enabled")
473
+
474
+ OUT_PATH = args.continue_path
475
+ if args.continue_path == '':
476
+ OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
477
+ args.debug)
478
+
479
+ AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
480
+
481
+ c_logger = ConsoleLogger()
482
+
483
+ if args.rank == 0:
484
+ os.makedirs(AUDIO_PATH, exist_ok=True)
485
+ new_fields = {}
486
+ if args.restore_path:
487
+ new_fields["restore_path"] = args.restore_path
488
+ new_fields["github_branch"] = get_git_branch()
489
+ copy_model_files(c, args.config_path,
490
+ OUT_PATH, new_fields)
491
+ os.chmod(AUDIO_PATH, 0o775)
492
+ os.chmod(OUT_PATH, 0o775)
493
+
494
+ LOG_DIR = OUT_PATH
495
+ tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
496
+
497
+ # write model desc to tensorboard
498
+ tb_logger.tb_add_text('model-description', c['run_description'], 0)
499
+
500
+ try:
501
+ main(args)
502
+ except KeyboardInterrupt:
503
+ remove_experiment_folder(OUT_PATH)
504
+ try:
505
+ sys.exit(0)
506
+ except SystemExit:
507
+ os._exit(0) # pylint: disable=protected-access
508
+ except Exception: # pylint: disable=broad-except
509
+ remove_experiment_folder(OUT_PATH)
510
+ traceback.print_exc()
511
+ sys.exit(1)
TTS/bin/train_vocoder_wavernn.py ADDED
@@ -0,0 +1,539 @@
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import traceback
5
+ import time
6
+ import glob
7
+ import random
8
+
9
+ import torch
10
+ from torch.utils.data import DataLoader
11
+
12
+ # from torch.utils.data.distributed import DistributedSampler
13
+
14
+ from TTS.tts.utils.visual import plot_spectrogram
15
+ from TTS.utils.audio import AudioProcessor
16
+ from TTS.utils.radam import RAdam
17
+ from TTS.utils.io import copy_model_files, load_config
18
+ from TTS.utils.training import setup_torch_training_env
19
+ from TTS.utils.console_logger import ConsoleLogger
20
+ from TTS.utils.tensorboard_logger import TensorboardLogger
21
+ from TTS.utils.generic_utils import (
22
+ KeepAverage,
23
+ count_parameters,
24
+ create_experiment_folder,
25
+ get_git_branch,
26
+ remove_experiment_folder,
27
+ set_init_dict,
28
+ )
29
+ from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
30
+ from TTS.vocoder.datasets.preprocess import (
31
+ load_wav_data,
32
+ load_wav_feat_data
33
+ )
34
+ from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
35
+ from TTS.vocoder.utils.generic_utils import setup_wavernn
36
+ from TTS.vocoder.utils.io import save_best_model, save_checkpoint
37
+
38
+
39
+ use_cuda, num_gpus = setup_torch_training_env(True, True)
40
+
41
+
42
+ def setup_loader(ap, is_val=False, verbose=False):
43
+ if is_val and not c.run_eval:
44
+ loader = None
45
+ else:
46
+ dataset = WaveRNNDataset(ap=ap,
47
+ items=eval_data if is_val else train_data,
48
+ seq_len=c.seq_len,
49
+ hop_len=ap.hop_length,
50
+ pad=c.padding,
51
+ mode=c.mode,
52
+ mulaw=c.mulaw,
53
+ is_training=not is_val,
54
+ verbose=verbose,
55
+ )
56
+ # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
57
+ loader = DataLoader(dataset,
58
+ shuffle=True,
59
+ collate_fn=dataset.collate,
60
+ batch_size=c.batch_size,
61
+ num_workers=c.num_val_loader_workers
62
+ if is_val
63
+ else c.num_loader_workers,
64
+ pin_memory=True,
65
+ )
66
+ return loader
67
+
68
+
69
+ def format_data(data):
70
+ # setup input data
71
+ x_input = data[0]
72
+ mels = data[1]
73
+ y_coarse = data[2]
74
+
75
+ # dispatch data to GPU
76
+ if use_cuda:
77
+ x_input = x_input.cuda(non_blocking=True)
78
+ mels = mels.cuda(non_blocking=True)
79
+ y_coarse = y_coarse.cuda(non_blocking=True)
80
+
81
+ return x_input, mels, y_coarse
82
+
83
+
84
+ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch):
85
+ # create train loader
86
+ data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
87
+ model.train()
88
+ epoch_time = 0
89
+ keep_avg = KeepAverage()
90
+ if use_cuda:
91
+ batch_n_iter = int(len(data_loader.dataset) /
92
+ (c.batch_size * num_gpus))
93
+ else:
94
+ batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
95
+ end_time = time.time()
96
+ c_logger.print_train_start()
97
+ # train loop
98
+ for num_iter, data in enumerate(data_loader):
99
+ start_time = time.time()
100
+ x_input, mels, y_coarse = format_data(data)
101
+ loader_time = time.time() - end_time
102
+ global_step += 1
103
+
104
+ optimizer.zero_grad()
105
+
106
+ if c.mixed_precision:
107
+ # mixed precision training
108
+ with torch.cuda.amp.autocast():
109
+ y_hat = model(x_input, mels)
110
+ if isinstance(model.mode, int):
111
+ y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
112
+ else:
113
+ y_coarse = y_coarse.float()
114
+ y_coarse = y_coarse.unsqueeze(-1)
115
+ # compute losses
116
+ loss = criterion(y_hat, y_coarse)
117
+ scaler.scale(loss).backward()
118
+ scaler.unscale_(optimizer)
119
+ if c.grad_clip > 0:
120
+ torch.nn.utils.clip_grad_norm_(
121
+ model.parameters(), c.grad_clip)
122
+ scaler.step(optimizer)
123
+ scaler.update()
124
+ else:
125
+ # full precision training
126
+ y_hat = model(x_input, mels)
127
+ if isinstance(model.mode, int):
128
+ y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
129
+ else:
130
+ y_coarse = y_coarse.float()
131
+ y_coarse = y_coarse.unsqueeze(-1)
132
+ # compute losses
133
+ loss = criterion(y_hat, y_coarse)
134
+ if torch.isnan(loss).any():
135
+ raise RuntimeError(" [!] NaN loss. Exiting ...")
136
+ loss.backward()
137
+ if c.grad_clip > 0:
138
+ torch.nn.utils.clip_grad_norm_(
139
+ model.parameters(), c.grad_clip)
140
+ optimizer.step()
141
+
142
+ if scheduler is not None:
143
+ scheduler.step()
144
+
145
+ # get the current learning rate
146
+ cur_lr = list(optimizer.param_groups)[0]["lr"]
147
+
148
+ step_time = time.time() - start_time
149
+ epoch_time += step_time
150
+
151
+ update_train_values = dict()
152
+ loss_dict = dict()
153
+ loss_dict["model_loss"] = loss.item()
154
+ for key, value in loss_dict.items():
155
+ update_train_values["avg_" + key] = value
156
+ update_train_values["avg_loader_time"] = loader_time
157
+ update_train_values["avg_step_time"] = step_time
158
+ keep_avg.update_values(update_train_values)
159
+
160
+ # print training stats
161
+ if global_step % c.print_step == 0:
162
+ log_dict = {"step_time": [step_time, 2],
163
+ "loader_time": [loader_time, 4],
164
+ "current_lr": cur_lr,
165
+ }
166
+ c_logger.print_train_step(batch_n_iter,
167
+ num_iter,
168
+ global_step,
169
+ log_dict,
170
+ loss_dict,
171
+ keep_avg.avg_values,
172
+ )
173
+
174
+ # plot step stats
175
+ if global_step % 10 == 0:
176
+ iter_stats = {"lr": cur_lr, "step_time": step_time}
177
+ iter_stats.update(loss_dict)
178
+ tb_logger.tb_train_iter_stats(global_step, iter_stats)
179
+
180
+ # save checkpoint
181
+ if global_step % c.save_step == 0:
182
+ if c.checkpoint:
183
+ # save model
184
+ save_checkpoint(model,
185
+ optimizer,
186
+ scheduler,
187
+ None,
188
+ None,
189
+ None,
190
+ global_step,
191
+ epoch,
192
+ OUT_PATH,
193
+ model_losses=loss_dict,
194
+ scaler=scaler.state_dict() if c.mixed_precision else None
195
+ )
196
+
197
+ # synthesize a full voice
198
+ rand_idx = random.randrange(0, len(train_data))
199
+ wav_path = train_data[rand_idx] if not isinstance(
200
+ train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
201
+ wav = ap.load_wav(wav_path)
202
+ ground_mel = ap.melspectrogram(wav)
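+ # batched generation splits the conditioning mel into overlapping folds
+ # (target_samples each, overlap_samples cross-faded at the seams) so the
+ # autoregressive loop can process several folds in parallel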
203
+ sample_wav = model.generate(ground_mel,
204
+ c.batched,
205
+ c.target_samples,
206
+ c.overlap_samples,
207
+ use_cuda
208
+ )
209
+ predict_mel = ap.melspectrogram(sample_wav)
210
+
211
+ # compute spectrograms
212
+ figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
213
+ "train/prediction": plot_spectrogram(predict_mel.T)
214
+ }
215
+ tb_logger.tb_train_figures(global_step, figures)
216
+
217
+ # Sample audio
218
+ tb_logger.tb_train_audios(
219
+ global_step, {
220
+ "train/audio": sample_wav}, c.audio["sample_rate"]
221
+ )
222
+ end_time = time.time()
223
+
224
+ # print epoch stats
225
+ c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
226
+
227
+ # Plot Training Epoch Stats
228
+ epoch_stats = {"epoch_time": epoch_time}
229
+ epoch_stats.update(keep_avg.avg_values)
230
+ tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
231
+ # TODO: plot model stats
232
+ # if c.tb_model_param_stats:
233
+ # tb_logger.tb_model_weights(model, global_step)
234
+ return keep_avg.avg_values, global_step
235
+
236
+
237
+ @torch.no_grad()
238
+ def evaluate(model, criterion, ap, global_step, epoch):
239
+ # create train loader
240
+ data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
241
+ model.eval()
242
+ epoch_time = 0
243
+ keep_avg = KeepAverage()
244
+ end_time = time.time()
245
+ c_logger.print_eval_start()
246
+ with torch.no_grad():
247
+ for num_iter, data in enumerate(data_loader):
248
+ start_time = time.time()
249
+ # format data
250
+ x_input, mels, y_coarse = format_data(data)
251
+ loader_time = time.time() - end_time
252
+ global_step += 1
253
+
254
+ y_hat = model(x_input, mels)
255
+ if isinstance(model.mode, int):
256
+ y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
257
+ else:
258
+ y_coarse = y_coarse.float()
259
+ y_coarse = y_coarse.unsqueeze(-1)
260
+ loss = criterion(y_hat, y_coarse)
261
+ # Compute avg loss
262
+ # if num_gpus > 1:
263
+ # loss = reduce_tensor(loss.data, num_gpus)
264
+ loss_dict = dict()
265
+ loss_dict["model_loss"] = loss.item()
266
+
267
+ step_time = time.time() - start_time
268
+ epoch_time += step_time
269
+
270
+ # update avg stats
271
+ update_eval_values = dict()
272
+ for key, value in loss_dict.items():
273
+ update_eval_values["avg_" + key] = value
274
+ update_eval_values["avg_loader_time"] = loader_time
275
+ update_eval_values["avg_step_time"] = step_time
276
+ keep_avg.update_values(update_eval_values)
277
+
278
+ # print eval stats
279
+ if c.print_eval:
280
+ c_logger.print_eval_step(
281
+ num_iter, loss_dict, keep_avg.avg_values)
282
+
283
+ if epoch % c.test_every_epochs == 0 and epoch != 0:
284
+ # synthesize a full voice
285
+ rand_idx = random.randrange(0, len(eval_data))
286
+ wav_path = eval_data[rand_idx] if not isinstance(
287
+ eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
288
+ wav = ap.load_wav(wav_path)
289
+ ground_mel = ap.melspectrogram(wav)
290
+ sample_wav = model.generate(ground_mel,
291
+ c.batched,
292
+ c.target_samples,
293
+ c.overlap_samples,
294
+ use_cuda
295
+ )
296
+ predict_mel = ap.melspectrogram(sample_wav)
297
+
298
+ # Sample audio
299
+ tb_logger.tb_eval_audios(
300
+ global_step, {
301
+ "eval/audio": sample_wav}, c.audio["sample_rate"]
302
+ )
303
+
304
+ # compute spectrograms
305
+ figures = {"eval/ground_truth": plot_spectrogram(ground_mel.T),
306
+ "eval/prediction": plot_spectrogram(predict_mel.T)
307
+ }
308
+ tb_logger.tb_eval_figures(global_step, figures)
309
+
310
+ tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
311
+ return keep_avg.avg_values
312
+
313
+
314
+ # FIXME: move args definition/parsing inside of main?
315
+ def main(args): # pylint: disable=redefined-outer-name
316
+ # pylint: disable=global-variable-undefined
317
+ global train_data, eval_data
318
+
319
+ # setup audio processor
320
+ ap = AudioProcessor(**c.audio)
321
+
322
+ # print(f" > Loading wavs from: {c.data_path}")
323
+ # if c.feature_path is not None:
324
+ # print(f" > Loading features from: {c.feature_path}")
325
+ # eval_data, train_data = load_wav_feat_data(
326
+ # c.data_path, c.feature_path, c.eval_split_size
327
+ # )
328
+ # else:
329
+ # mel_feat_path = os.path.join(OUT_PATH, "mel")
330
+ # feat_data = find_feat_files(mel_feat_path)
331
+ # if feat_data:
332
+ # print(f" > Loading features from: {mel_feat_path}")
333
+ # eval_data, train_data = load_wav_feat_data(
334
+ # c.data_path, mel_feat_path, c.eval_split_size
335
+ # )
336
+ # else:
337
+ # print(" > No feature data found. Preprocessing...")
338
+ # # preprocessing feature data from given wav files
339
+ # preprocess_wav_files(OUT_PATH, CONFIG, ap)
340
+ # eval_data, train_data = load_wav_feat_data(
341
+ # c.data_path, mel_feat_path, c.eval_split_size
342
+ # )
343
+
344
+ print(f" > Loading wavs from: {c.data_path}")
345
+ if c.feature_path is not None:
346
+ print(f" > Loading features from: {c.feature_path}")
347
+ eval_data, train_data = load_wav_feat_data(
348
+ c.data_path, c.feature_path, c.eval_split_size)
349
+ else:
350
+ eval_data, train_data = load_wav_data(
351
+ c.data_path, c.eval_split_size)
352
+ # setup model
353
+ model_wavernn = setup_wavernn(c)
354
+
355
+ # setup amp scaler
356
+ scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
357
+
358
+ # define train functions
359
+ if c.mode == "mold":
360
+ criterion = discretized_mix_logistic_loss
361
+ elif c.mode == "gauss":
362
+ criterion = gaussian_loss
363
+ elif isinstance(c.mode, int):
364
+ criterion = torch.nn.CrossEntropyLoss()
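+ # 'mold' fits a discretized mixture of logistics, 'gauss' a single Gaussian,
+ # and an integer mode treats the output as a categorical distribution over
+ # the quantized amplitude classes, hence cross-entropy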
365
+
366
+ if use_cuda:
367
+ model_wavernn.cuda()
368
+ if isinstance(c.mode, int):
369
+ criterion.cuda()
370
+
371
+ optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)
372
+
373
+ scheduler = None
374
+ if "lr_scheduler" in c:
375
+ scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
376
+ scheduler = scheduler(optimizer, **c.lr_scheduler_params)
377
+ # slow start for the first 5 epochs
378
+ # lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
379
+ # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
380
+
381
+ # restore any checkpoint
382
+ if args.restore_path:
383
+ checkpoint = torch.load(args.restore_path, map_location="cpu")
384
+ try:
385
+ print(" > Restoring Model...")
386
+ model_wavernn.load_state_dict(checkpoint["model"])
387
+ print(" > Restoring Optimizer...")
388
+ optimizer.load_state_dict(checkpoint["optimizer"])
389
+ if "scheduler" in checkpoint:
390
+ print(" > Restoring Generator LR Scheduler...")
391
+ scheduler.load_state_dict(checkpoint["scheduler"])
392
+ scheduler.optimizer = optimizer
393
+ if "scaler" in checkpoint and c.mixed_precision:
394
+ print(" > Restoring AMP Scaler...")
395
+ scaler.load_state_dict(checkpoint["scaler"])
396
+ except RuntimeError:
397
+ # restore only matching layers.
398
+ print(" > Partial model initialization...")
399
+ model_dict = model_wavernn.state_dict()
400
+ model_dict = set_init_dict(model_dict, checkpoint["model"], c)
401
+ model_wavernn.load_state_dict(model_dict)
402
+
403
+ print(" > Model restored from step %d" %
404
+ checkpoint["step"], flush=True)
405
+ args.restore_step = checkpoint["step"]
406
+ else:
407
+ args.restore_step = 0
408
+
409
+ # DISTRIBUTED
410
+ # if num_gpus > 1:
411
+ # model = apply_gradient_allreduce(model)
412
+
413
+ num_parameters = count_parameters(model_wavernn)
414
+ print(" > Model has {} parameters".format(num_parameters), flush=True)
415
+
416
+ if "best_loss" not in locals():
417
+ best_loss = float("inf")
418
+
419
+ global_step = args.restore_step
420
+ for epoch in range(0, c.epochs):
421
+ c_logger.print_epoch_start(epoch, c.epochs)
422
+ _, global_step = train(model_wavernn, optimizer,
423
+ criterion, scheduler, scaler, ap, global_step, epoch)
424
+ eval_avg_loss_dict = evaluate(
425
+ model_wavernn, criterion, ap, global_step, epoch)
426
+ c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
427
+ target_loss = eval_avg_loss_dict["avg_model_loss"]
428
+ best_loss = save_best_model(
429
+ target_loss,
430
+ best_loss,
431
+ model_wavernn,
432
+ optimizer,
433
+ scheduler,
434
+ None,
435
+ None,
436
+ None,
437
+ global_step,
438
+ epoch,
439
+ OUT_PATH,
440
+ model_losses=eval_avg_loss_dict,
441
+ scaler=scaler.state_dict() if c.mixed_precision else None
442
+ )
443
+
444
+
445
+ if __name__ == "__main__":
446
+ parser = argparse.ArgumentParser()
447
+ parser.add_argument(
448
+ "--continue_path",
449
+ type=str,
450
+ help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
451
+ default="",
452
+ required="--config_path" not in sys.argv,
453
+ )
454
+ parser.add_argument(
455
+ "--restore_path",
456
+ type=str,
457
+ help="Model file to be restored. Use to finetune a model.",
458
+ default="",
459
+ )
460
+ parser.add_argument(
461
+ "--config_path",
462
+ type=str,
463
+ help="Path to config file for training.",
464
+ required="--continue_path" not in sys.argv,
465
+ )
466
+ parser.add_argument(
467
+ "--debug",
468
+ type=bool,
469
+ default=False,
470
+ help="Do not verify commit integrity to run training.",
471
+ )
472
+
473
+ # DISTRIBUTED
474
+ parser.add_argument(
475
+ "--rank",
476
+ type=int,
477
+ default=0,
478
+ help="DISTRIBUTED: process rank for distributed training.",
479
+ )
480
+ parser.add_argument(
481
+ "--group_id", type=str, default="", help="DISTRIBUTED: process group id."
482
+ )
483
+ args = parser.parse_args()
484
+
485
+ if args.continue_path != "":
486
+ args.output_path = args.continue_path
487
+ args.config_path = os.path.join(args.continue_path, "config.json")
488
+ list_of_files = glob.glob(
489
+ args.continue_path + "/*.pth.tar"
490
+ ) # pick up all checkpoint files in the folder
491
+ latest_model_file = max(list_of_files, key=os.path.getctime)
492
+ args.restore_path = latest_model_file
493
+ print(f" > Training continues for {args.restore_path}")
494
+
495
+ # setup output paths and read configs
496
+ c = load_config(args.config_path)
497
+ # check_config(c)
498
+ _ = os.path.dirname(os.path.realpath(__file__))
499
+
500
+ OUT_PATH = args.continue_path
501
+ if args.continue_path == "":
502
+ OUT_PATH = create_experiment_folder(
503
+ c.output_path, c.run_name, args.debug
504
+ )
505
+
506
+ AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")
507
+
508
+ c_logger = ConsoleLogger()
509
+
510
+ if args.rank == 0:
511
+ os.makedirs(AUDIO_PATH, exist_ok=True)
512
+ new_fields = {}
513
+ if args.restore_path:
514
+ new_fields["restore_path"] = args.restore_path
515
+ new_fields["github_branch"] = get_git_branch()
516
+ copy_model_files(
517
+ c, args.config_path, OUT_PATH, new_fields
518
+ )
519
+ os.chmod(AUDIO_PATH, 0o775)
520
+ os.chmod(OUT_PATH, 0o775)
521
+
522
+ LOG_DIR = OUT_PATH
523
+ tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
524
+
525
+ # write model desc to tensorboard
526
+ tb_logger.tb_add_text("model-description", c["run_description"], 0)
527
+
528
+ try:
529
+ main(args)
530
+ except KeyboardInterrupt:
531
+ remove_experiment_folder(OUT_PATH)
532
+ try:
533
+ sys.exit(0)
534
+ except SystemExit:
535
+ os._exit(0) # pylint: disable=protected-access
536
+ except Exception: # pylint: disable=broad-except
537
+ remove_experiment_folder(OUT_PATH)
538
+ traceback.print_exc()
539
+ sys.exit(1)
TTS/bin/tune_wavegrad.py ADDED
@@ -0,0 +1,91 @@
1
+ """Search a good noise schedule for WaveGrad for a given number of inferece iterations"""
2
+ import argparse
3
+ from itertools import product as cartesian_product
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+ from tqdm import tqdm
9
+ from TTS.utils.audio import AudioProcessor
10
+ from TTS.utils.io import load_config
11
+ from TTS.vocoder.datasets.preprocess import load_wav_data
12
+ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
13
+ from TTS.vocoder.utils.generic_utils import setup_generator
14
+
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument('--model_path', type=str, help='Path to model checkpoint.')
17
+ parser.add_argument('--config_path', type=str, help='Path to model config file.')
18
+ parser.add_argument('--data_path', type=str, help='Path to data directory.')
19
+ parser.add_argument('--output_path', type=str, help='Path to the output file, including file name and extension.')
20
+ parser.add_argument('--num_iter', type=int, help='Number of inference iterations to optimize the noise schedule for.')
21
+ parser.add_argument('--use_cuda', type=bool, help='enable/disable CUDA.')
22
+ parser.add_argument('--num_samples', type=int, default=1, help='Number of data samples used for inference.')
23
+ parser.add_argument('--search_depth', type=int, default=3, help='Search granularity. Increasing this increases the run-time exponentially.')
24
+
25
+ # load config
26
+ args = parser.parse_args()
27
+ config = load_config(args.config_path)
28
+
29
+ # setup audio processor
30
+ ap = AudioProcessor(**config.audio)
31
+
32
+ # load dataset
33
+ _, train_data = load_wav_data(args.data_path, 0)
34
+ train_data = train_data[:args.num_samples]
35
+ dataset = WaveGradDataset(ap=ap,
36
+ items=train_data,
37
+ seq_len=-1,
38
+ hop_len=ap.hop_length,
39
+ pad_short=config.pad_short,
40
+ conv_pad=config.conv_pad,
41
+ is_training=True,
42
+ return_segments=False,
43
+ use_noise_augment=False,
44
+ use_cache=False,
45
+ verbose=True)
46
+ loader = DataLoader(
47
+ dataset,
48
+ batch_size=1,
49
+ shuffle=False,
50
+ collate_fn=dataset.collate_full_clips,
51
+ drop_last=False,
52
+ num_workers=config.num_loader_workers,
53
+ pin_memory=False)
54
+
55
+ # setup the model
56
+ model = setup_generator(config)
57
+ if args.use_cuda:
58
+ model.cuda()
59
+
60
+ # setup optimization parameters
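+ # each candidate schedule is beta_i = base_i * 10^e_i, with exponents e_i fixed
+ # on a log grid over [-6, -1]; the grid search below tries every combination of
+ # base values (len(base_values) ** num_iter schedules in total)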
61
+ base_values = sorted(10 * np.random.uniform(size=args.search_depth))
62
+ print(f" > Base values: {base_values}")
63
+ exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
64
+ best_error = float('inf')
65
+ best_schedule = None
66
+ total_search_iter = len(base_values)**args.num_iter
67
+ for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
68
+ beta = exponents * base
69
+ model.compute_noise_level(beta)
70
+ for data in loader:
71
+ mel, audio = data
72
+ y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
73
+
74
+ if args.use_cuda:
75
+ y_hat = y_hat.cpu()
76
+ y_hat = y_hat.numpy()
77
+
78
+ mel_hat = []
79
+ for i in range(y_hat.shape[0]):
80
+ m = ap.melspectrogram(y_hat[i, 0])[:, :-1]
81
+ mel_hat.append(torch.from_numpy(m))
82
+
83
+ mel_hat = torch.stack(mel_hat)
84
+ mse = torch.sum((mel - mel_hat) ** 2).mean()
85
+ if mse.item() < best_error:
86
+ best_error = mse.item()
87
+ best_schedule = {'beta': beta}
88
+ print(f" > Found a better schedule. - MSE: {mse.item()}")
89
+ np.save(args.output_path, best_schedule)
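+ # np.save pickles the dict, so load it back with
+ # np.load(path, allow_pickle=True).item()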
90
+
91
+
TTS/server/README.md ADDED
@@ -0,0 +1,65 @@
1
+ ## TTS example web-server
2
+
3
+ You'll need a model package (a zip file that includes the TTS Python wheel, model files, server configuration, and optional nginx/uwsgi configs). Publicly available models are listed [here](https://github.com/mozilla/TTS/wiki/Released-Models#simple-packaging---self-contained-package-that-runs-an-http-api-for-a-pre-trained-tts-model).
4
+
5
+ The instructions below are based on an Ubuntu 18.04 machine, but it should be simple to adapt the package names to other distros if needed. Python 3.6 is recommended, as some of the dependencies' versions predate Python 3.7 and would force building from source, which requires extra dependencies and is not guaranteed to work.
6
+
7
+ #### Development server:
8
+
9
+ ##### Using server.py
10
+ If you already have a working TTS environment, you can call ```server.py``` directly.
11
+
12
+ **Note:** After installing TTS as a package you can use ```tts-server``` to call the commands below.
13
+
14
+ Example runs:
15
+
16
+ List officially released models.
17
+ ```python TTS/server/server.py --list_models ```
18
+
19
+ Run the server with the official models.
20
+ ```python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/mulitband-melgan```
21
+
22
+ Run the server with the official models on a GPU.
23
+ ```CUDA_VISIBLE_DEVICES="0" python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/mulitband-melgan --use_cuda True```
24
+
25
+ Run the server with custom models.
26
+ ```python TTS/server/server.py --tts_checkpoint /path/to/tts/model.pth.tar --tts_config /path/to/tts/config.json --vocoder_checkpoint /path/to/vocoder/model.pth.tar --vocoder_config /path/to/vocoder/config.json```
27
+
28
+ ##### Using .whl
29
+ 1. apt-get install -y espeak libsndfile1 python3-venv
30
+ 2. python3 -m venv /tmp/venv
31
+ 3. source /tmp/venv/bin/activate
32
+ 4. pip install -U pip setuptools wheel
33
+ 5. pip install -U https://example.com/url/to/python/package.whl
34
+ 6. python -m TTS.server.server
35
+
36
+ You can now open http://localhost:5002 in a browser
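+ 
+ As a quick check you can also call the HTTP API directly. Below is a minimal Python sketch using the `requests` package (an extra dependency, not shipped with the server), assuming the server is running locally on the default port:
+ 
+ ```python
+ import requests
+ 
+ # /api/tts accepts a GET request with the input text as a query parameter
+ resp = requests.get("http://localhost:5002/api/tts", params={"text": "Hello, world!"})
+ resp.raise_for_status()
+ 
+ # the response body is a WAV file
+ with open("hello.wav", "wb") as f:
+     f.write(resp.content)
+ ```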
37
+
38
+ #### Running with nginx/uwsgi:
39
+
40
+ **Note:** This method uses an old TTS model, so quality might be low.
41
+
42
+ 1. apt-get install -y uwsgi uwsgi-plugin-python3 nginx espeak libsndfile1 python3-venv
43
+ 2. python3 -m venv /tmp/venv
44
+ 3. source /tmp/venv/bin/activate
45
+ 4. pip install -U pip setuptools wheel
46
+ 5. pip install -U https://example.com/url/to/python/package.whl
47
+ 6. curl -LO https://github.com/reuben/TTS/releases/download/t2-ljspeech-mold/t2-ljspeech-mold-nginx-uwsgi.zip
48
+ 7. unzip *-nginx-uwsgi.zip
49
+ 8. cp tts_site_nginx /etc/nginx/sites-enabled/default
50
+ 9. service nginx restart
51
+ 10. uwsgi --ini uwsgi.ini
52
+
53
+ You can now open http://localhost:80 in a browser (edit the port in /etc/nginx/sites-enabled/tts_site_nginx).
54
+ Configure the number of workers (the number of requests processed in parallel) by editing the `uwsgi.ini` file, specifically the `processes` setting.
55
+
56
+ #### Creating a server package with an embedded model
57
+
58
+ [setup.py](../setup.py) was extended with two new parameters when running the `bdist_wheel` command:
59
+
60
+ - `--checkpoint <path to checkpoint file>` - path to model checkpoint file you want to embed in the package
61
+ - `--model_config <path to config.json file>` - path to corresponding config.json file for the checkpoint
62
+
63
+ To create a package, run `python setup.py bdist_wheel --checkpoint /path/to/checkpoint --model_config /path/to/config.json`.
64
+
65
+ A Python `.whl` file will be created in the `dist/` folder with the checkpoint and config embedded in it.
TTS/server/__init__.py ADDED
File without changes
TTS/server/conf.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "tts_path":"/media/erogol/data_ssd/Models/libri_tts/5049/", // tts model root folder
3
+ "tts_file":"best_model.pth.tar", // tts checkpoint file
4
+ "tts_config":"config.json", // tts config.json file
5
+ "tts_speakers": null, // json file listing speaker ids. null if no speaker embedding.
6
+ "vocoder_config":null,
7
+ "vocoder_file": null,
8
+ "is_wavernn_batched":true,
9
+ "port": 5002,
10
+ "use_cuda": true,
11
+ "debug": true
12
+ }
TTS/server/server.py ADDED
@@ -0,0 +1,116 @@
1
+ #!flask/bin/python
2
+ import argparse
3
+ import os
4
+ import sys
5
+ import io
6
+ from pathlib import Path
7
+
8
+ from flask import Flask, render_template, request, send_file
9
+ from TTS.utils.synthesizer import Synthesizer
10
+ from TTS.utils.manage import ModelManager
11
+ from TTS.utils.io import load_config
12
+
13
+
14
+ def create_argparser():
15
+ def convert_boolean(x):
16
+ return x.lower() in ['true', '1', 'yes']
17
+
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument('--list_models', type=convert_boolean, nargs='?', const=True, default=False, help='list available pre-trained tts and vocoder models.')
20
+ parser.add_argument('--model_name', type=str, help='name of one of the released tts models.')
21
+ parser.add_argument('--vocoder_name', type=str, help='name of one of the released vocoder models.')
22
+ parser.add_argument('--tts_checkpoint', type=str, help='path to custom tts checkpoint file')
23
+ parser.add_argument('--tts_config', type=str, help='path to custom tts config.json file')
24
+ parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
25
+ parser.add_argument('--vocoder_config', type=str, default=None, help='path to vocoder config file.')
26
+ parser.add_argument('--vocoder_checkpoint', type=str, default=None, help='path to vocoder checkpoint file.')
27
+ parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
28
+ parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
29
+ parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
30
+ parser.add_argument('--show_details', type=convert_boolean, default=False, help='Generate model detail page.')
31
+ return parser
32
+
33
+ synthesizer = None
34
+
35
+ embedded_models_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model')
36
+
37
+ embedded_tts_folder = os.path.join(embedded_models_folder, 'tts')
38
+ tts_checkpoint_file = os.path.join(embedded_tts_folder, 'checkpoint.pth.tar')
39
+ tts_config_file = os.path.join(embedded_tts_folder, 'config.json')
40
+
41
+ embedded_vocoder_folder = os.path.join(embedded_models_folder, 'vocoder')
42
+ vocoder_checkpoint_file = os.path.join(embedded_vocoder_folder, 'checkpoint.pth.tar')
43
+ vocoder_config_file = os.path.join(embedded_vocoder_folder, 'config.json')
44
+
45
+ # These models are soon to be deprecated
46
+ embedded_wavernn_folder = os.path.join(embedded_models_folder, 'wavernn')
47
+ wavernn_checkpoint_file = os.path.join(embedded_wavernn_folder, 'checkpoint.pth.tar')
48
+ wavernn_config_file = os.path.join(embedded_wavernn_folder, 'config.json')
49
+
50
+ args = create_argparser().parse_args()
51
+
52
+ path = Path(__file__).parent / "../.models.json"
53
+ manager = ModelManager(path)
54
+
55
+ if args.list_models:
56
+ manager.list_models()
57
+ sys.exit()
58
+
59
+ # set models by the released models
60
+ if args.model_name is not None:
61
+ tts_checkpoint_file, tts_config_file = manager.download_model(args.model_name)
62
+
63
+ if args.vocoder_name is not None:
64
+ vocoder_checkpoint_file, vocoder_config_file = manager.download_model(args.vocoder_name)
65
+
66
+ # If these were not specified in the CLI args, use default values with embedded model files
67
+ if not args.tts_checkpoint and os.path.isfile(tts_checkpoint_file):
68
+ args.tts_checkpoint = tts_checkpoint_file
69
+ if not args.tts_config and os.path.isfile(tts_config_file):
70
+ args.tts_config = tts_config_file
71
+
72
+ if not args.vocoder_checkpoint and os.path.isfile(vocoder_checkpoint_file):
73
+ args.vocoder_checkpoint = vocoder_checkpoint_file
74
+ if not args.vocoder_config and os.path.isfile(vocoder_config_file):
75
+ args.vocoder_config = vocoder_config_file
76
+
77
+ synthesizer = Synthesizer(args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda)
78
+
79
+ app = Flask(__name__)
80
+
81
+
82
+ @app.route('/')
83
+ def index():
84
+ return render_template('index.html', show_details=args.show_details)
85
+
86
+ @app.route('/details')
87
+ def details():
88
+ model_config = load_config(args.tts_config)
89
+ if args.vocoder_config is not None and os.path.isfile(args.vocoder_config):
90
+ vocoder_config = load_config(args.vocoder_config)
91
+ else:
92
+ vocoder_config = None
93
+
94
+ return render_template('details.html',
95
+ show_details=args.show_details,
96
+ model_config=model_config,
97
+ vocoder_config=vocoder_config,
98
+ args=args.__dict__
99
+ )
100
+
101
+ @app.route('/api/tts', methods=['GET'])
102
+ def tts():
103
+ text = request.args.get('text')
104
+ print(" > Model input: {}".format(text))
105
+ wavs = synthesizer.tts(text)
106
+ out = io.BytesIO()
107
+ synthesizer.save_wav(wavs, out)
108
+ return send_file(out, mimetype='audio/wav')
109
+
110
+
111
+ def main():
112
+ app.run(debug=args.debug, host='0.0.0.0', port=args.port)
113
+
114
+
115
+ if __name__ == '__main__':
116
+ main()
TTS/server/static/TTS_circle.png ADDED
TTS/server/templates/details.html ADDED
@@ -0,0 +1,131 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+
6
+ <meta charset="utf-8">
7
+ <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
8
+ <meta name="description" content="">
9
+ <meta name="author" content="">
10
+
11
+ <title>TTS engine</title>
12
+
13
+ <!-- Bootstrap core CSS -->
14
+ <link href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.1/css/bootstrap.min.css"
15
+ integrity="sha384-WskhaSGFgHYWDcbwN70/dfYBj47jz9qbsMId/iRN3ewGhXQFZCSftd1LZCfmhktB" crossorigin="anonymous"
16
+ rel="stylesheet">
17
+
18
+ <!-- Custom styles for this template -->
19
+ <style>
20
+ body {
21
+ padding-top: 54px;
22
+ }
23
+
24
+ @media (min-width: 992px) {
25
+ body {
26
+ padding-top: 56px;
27
+ }
28
+ }
29
+ </style>
30
+ </head>
31
+
32
+ <body>
33
+ <a href="https://github.com/mozilla/TTS"><img style="position: absolute; z-index:1000; top: 0; left: 0; border: 0;"
34
+ src="https://s3.amazonaws.com/github/ribbons/forkme_left_darkblue_121621.png" alt="Fork me on GitHub"></a>
35
+
36
+ {% if show_details == true %}
37
+
38
+ <div class="container">
39
+ <b>Model details</b>
40
+ </div>
41
+
42
+ <div class="container">
43
+ <details>
44
+ <summary>CLI arguments:</summary>
45
+ <table border="1" align="center" width="75%">
46
+ <tr>
47
+ <td> CLI key </td>
48
+ <td> Value </td>
49
+ </tr>
50
+
51
+ {% for key, value in args.items() %}
52
+
53
+ <tr>
54
+ <td>{{ key }}</td>
55
+ <td>{{ value }}</td>
56
+ </tr>
57
+
58
+ {% endfor %}
59
+ </table>
60
+ </details>
61
+ </div><br/>
62
+
63
+ <div class="container">
64
+
65
+ {% if model_config != None %}
66
+
67
+ <details>
68
+ <summary>Model config:</summary>
69
+
70
+ <table border="1" align="center" width="75%">
71
+ <tr>
72
+ <td> Key </td>
73
+ <td> Value </td>
74
+ </tr>
75
+
76
+
77
+ {% for key, value in model_config.items() %}
78
+
79
+ <tr>
80
+ <td>{{ key }}</td>
81
+ <td>{{ value }}</td>
82
+ </tr>
83
+
84
+ {% endfor %}
85
+
86
+ </table>
87
+ </details>
88
+
89
+ {% endif %}
90
+
91
+ </div><br/>
92
+
93
+
94
+
95
+ <div class="container">
96
+ {% if vocoder_config != None %}
97
+ <details>
98
+ <summary>Vocoder model config:</summary>
99
+
100
+ <table border="1" align="center" width="75%">
101
+ <tr>
102
+ <td> Key </td>
103
+ <td> Value </td>
104
+ </tr>
105
+
106
+
107
+ {% for key, value in vocoder_config.items() %}
108
+
109
+ <tr>
110
+ <td>{{ key }}</td>
111
+ <td>{{ value }}</td>
112
+ </tr>
113
+
114
+ {% endfor %}
115
+
116
+
117
+ </table>
118
+ </details>
119
+ {% endif %}
120
+ </div><br/>
121
+
122
+ {% else %}
123
+ <div class="container">
124
+ <b>Please start the server with --show_details=true to see details.</b>
125
+ </div>
126
+
127
+ {% endif %}
128
+
129
+ </body>
130
+
131
+ </html>
TTS/server/templates/index.html ADDED
@@ -0,0 +1,114 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
+ <head>
+
+     <meta charset="utf-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
+     <meta name="description" content="">
+     <meta name="author" content="">
+
+     <title>TTS engine</title>
+
+     <!-- Bootstrap core CSS -->
+     <link href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.1/css/bootstrap.min.css"
+         integrity="sha384-WskhaSGFgHYWDcbwN70/dfYBj47jz9qbsMId/iRN3ewGhXQFZCSftd1LZCfmhktB" crossorigin="anonymous" rel="stylesheet">
+
+     <!-- Custom styles for this template -->
+     <style>
+         body {
+             padding-top: 54px;
+         }
+         @media (min-width: 992px) {
+             body {
+                 padding-top: 56px;
+             }
+         }
+
+     </style>
+ </head>
+
+ <body>
+     <a href="https://github.com/mozilla/TTS"><img style="position: absolute; z-index:1000; top: 0; left: 0; border: 0;" src="https://s3.amazonaws.com/github/ribbons/forkme_left_darkblue_121621.png" alt="Fork me on GitHub"></a>
+
+     <!-- Navigation -->
+     <!--
+     <nav class="navbar navbar-expand-lg navbar-dark bg-dark fixed-top">
+         <div class="container">
+             <a class="navbar-brand" href="#">Mozilla TTS</a>
+             <button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbarResponsive" aria-controls="navbarResponsive" aria-expanded="false" aria-label="Toggle navigation">
+                 <span class="navbar-toggler-icon"></span>
+             </button>
+             <div class="collapse navbar-collapse" id="navbarResponsive">
+                 <ul class="navbar-nav ml-auto">
+                     <li class="nav-item active">
+                         <a class="nav-link" href="#">Home
+                             <span class="sr-only">(current)</span>
+                         </a>
+                     </li>
+                 </ul>
+             </div>
+         </div>
+     </nav>
+     -->
+
+     <!-- Page Content -->
+     <div class="container">
+         <div class="row">
+             <div class="col-lg-12 text-center">
+                 <img class="mt-5" src="{{url_for('static', filename='TTS_circle.png')}}" align="middle" />
+
+                 <ul class="list-unstyled">
+                 </ul>
+                 <input id="text" placeholder="Type here..." size="45" type="text" name="text">
+                 <button id="speak-button" name="speak">Speak</button><br/><br/>
+                 {%if show_details%}
+                 <button id="details-button" onclick="location.href = 'details'" name="model-details">Model Details</button><br/><br/>
+                 {%endif%}
+                 <audio id="audio" controls autoplay hidden></audio>
+                 <p id="message"></p>
+             </div>
+         </div>
+     </div>
+
+     <!-- Bootstrap core JavaScript -->
+     <script>
+         function q(selector) {return document.querySelector(selector)}
+         q('#text').focus()
+         function do_tts(e) {
+             var text = q('#text').value
+             if (text) {
+                 q('#message').textContent = 'Synthesizing...'
+                 q('#speak-button').disabled = true
+                 q('#audio').hidden = true
+                 synthesize(text)
+             }
+             e.preventDefault()
+             return false
+         }
+         q('#speak-button').addEventListener('click', do_tts)
+         q('#text').addEventListener('keyup', function(e) {
+             if (e.keyCode == 13) { // the Enter key triggers synthesis
+                 do_tts(e)
+             }
+         })
+         function synthesize(text) {
+             fetch('/api/tts?text=' + encodeURIComponent(text), {cache: 'no-cache'})
+                 .then(function(res) {
+                     if (!res.ok) throw Error(res.statusText)
+                     return res.blob()
+                 }).then(function(blob) {
+                     q('#message').textContent = ''
+                     q('#speak-button').disabled = false
+                     q('#audio').src = URL.createObjectURL(blob)
+                     q('#audio').hidden = false
+                 }).catch(function(err) {
+                     q('#message').textContent = 'Error: ' + err.message
+                     q('#speak-button').disabled = false
+                 })
+         }
+     </script>
+
+ </body>
+
+ </html>
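
The `synthesize()` function above drives the server's `/api/tts` endpoint, which returns the rendered speech as a wav blob. Below is a minimal sketch of calling the same endpoint outside the browser; it assumes the demo server from `TTS/server/server.py` is running locally, and the host/port and output path are placeholders to adjust for your setup.

```python
# A sketch, not part of the repo: query the demo server's /api/tts endpoint.
import urllib.parse
import urllib.request

def tts_request(text, host="http://localhost:5002"):  # host/port are assumptions
    """Fetch synthesized speech for `text` and return the raw wav bytes."""
    url = host + "/api/tts?text=" + urllib.parse.quote(text)
    with urllib.request.urlopen(url) as response:
        return response.read()

if __name__ == "__main__":
    wav_bytes = tts_request("Hello from the TTS server.")
    with open("out.wav", "wb") as f:  # placeholder output path
        f.write(wav_bytes)
```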
TTS/speaker_encoder/README.md ADDED
@@ -0,0 +1,18 @@
+ ### Speaker Encoder
+
+ This is an implementation of https://arxiv.org/abs/1710.10467. The model computes speaker embeddings (d-vectors).
+
+ With the code here you can generate d-vectors for both multi-speaker and single-speaker TTS datasets, then visualise and explore them along with the associated audio files in an interactive chart.
+
+ Below is an example showing embedding results for various speakers. You can generate the same plot with the provided notebook, as demonstrated in [this video](https://youtu.be/KW3oO7JVa7Q).
+
+ ![](umap.png)
+
+ Download a pretrained model from the [Released Models](https://github.com/mozilla/TTS/wiki/Released-Models) page.
+
+ To run the code, follow the same flow as in TTS:
+
+ - Define 'config.json' for your needs. Note that the audio parameters should match those of your TTS model.
+ - Example training call: ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360```
+ - Generate embedding vectors: ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth.tar model/config/path/config.json dataset/path/ output_path```. This parses all .wav files under the given dataset path and recreates the same folder structure under the output path with the generated embedding files.
+ - Watch training on Tensorboard as in TTS.
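
For completeness, here is a minimal sketch of computing a single d-vector in Python rather than through `compute_embeddings.py`. The checkpoint and file paths are placeholders, and the `load_config`/`AudioProcessor` import paths and the `"model"` checkpoint key are assumptions based on how the rest of this repo loads models.

```python
# A sketch, not the official API: compute one d-vector for a wav file.
import torch

from TTS.speaker_encoder.model import SpeakerEncoder
from TTS.utils.audio import AudioProcessor  # assumed module path
from TTS.utils.io import load_config        # assumed module path

config = load_config("speaker_encoder/config.json")  # placeholder path
ap = AudioProcessor(**config.audio)

model = SpeakerEncoder(**config.model)
checkpoint = torch.load("best_model.pth.tar", map_location="cpu")  # placeholder
model.load_state_dict(checkpoint["model"])  # assumed checkpoint layout
model.eval()

wav = ap.load_wav("sample.wav", sr=ap.sample_rate)  # placeholder path
mel = ap.melspectrogram(wav)                        # (num_mels, T)
mel = torch.FloatTensor(mel.T).unsqueeze(0)         # (1, T, num_mels)
embedding = model.compute_embedding(mel)            # (1, proj_dim) d-vector
```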
TTS/speaker_encoder/__init__.py ADDED
File without changes
TTS/speaker_encoder/config.json ADDED
@@ -0,0 +1,103 @@
+
+ {
+     "run_name": "mueller91",
+     "run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech",
+     "audio":{
+         // Audio processing parameters
+         "num_mels": 40,          // size of the mel spec frame.
+         "fft_size": 400,         // number of stft frequency levels. Size of the linear spectrogram frame.
+         "sample_rate": 16000,    // DATASET-RELATED: wav sample rate. If different from the original data, it is resampled.
+         "win_length": 400,       // stft window length in samples (25 ms at 16 kHz).
+         "hop_length": 160,       // stft window hop length in samples (10 ms at 16 kHz).
+         "frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
+         "frame_shift_ms": null,  // stft window hop length in ms. If null, 'hop_length' is used.
+         "preemphasis": 0.98,     // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
+         "min_level_db": -100,    // normalization range
+         "ref_level_db": 20,      // reference level db; theoretically 20 db is the sound of air.
+         "power": 1.5,            // value to sharpen wav signals after the GL algorithm.
+         "griffin_lim_iters": 60, // number of griffin-lim iterations. 30-60 is a good range. The larger the value, the slower the generation.
+         // Normalization parameters
+         "signal_norm": true,     // normalize the spec values to the range [0, 1]
+         "symmetric_norm": true,  // move normalization to the range [-1, 1]
+         "max_norm": 4.0,         // scale normalization to the range [-max_norm, max_norm] or [0, max_norm]
+         "clip_norm": true,       // clip normalized values into the range.
+         "mel_fmin": 0.0,         // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for your dataset!!
+         "mel_fmax": 8000.0,      // maximum freq level for mel-spec. Tune for your dataset!!
+         "do_trim_silence": true, // enable trimming of silence as audio is loaded. LJSpeech (false), TWEB (false), Nancy (true)
+         "trim_db": 60            // threshold for trimming silence. Set this according to your dataset.
+     },
+     "reinit_layers": [],
+     "loss": "angleproto",          // "ge2e" to use Generalized End-to-End loss, "angleproto" to use Angular Prototypical loss (new SOTA)
+     "grad_clip": 3.0,              // upper limit for gradient clipping.
+     "epochs": 1000,                // total number of epochs to train.
+     "lr": 0.0001,                  // initial learning rate. If Noam decay is active, maximum learning rate.
+     "lr_decay": false,             // if true, Noam learning rate decay is applied through training.
+     "warmup_steps": 4000,          // Noam decay steps to increase the learning rate from 0 to "lr"
+     "tb_model_param_stats": false, // if true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
+     "steps_plot_stats": 10,        // number of steps between embedding plots.
+     "num_speakers_in_batch": 64,   // batch size for training: the number of speakers per batch.
+     "num_utters_per_speaker": 10,  // number of utterances sampled per speaker.
+     "num_loader_workers": 8,       // number of training data loader processes. Don't set it too big. 4-8 are good values.
+     "wd": 0.000001,                // weight decay weight.
+     "checkpoint": true,            // if true, saves checkpoints every "save_step"
+     "save_step": 1000,             // number of training steps between saving training stats and checkpoints.
+     "print_step": 20,              // number of steps between logging training on the console.
+     "output_path": "../../MozillaTTSOutput/checkpoints/voxceleb_librispeech/speaker_encoder/", // DATASET-RELATED: output path for all training outputs.
+     "model": {
+         "input_dim": 40,
+         "proj_dim": 256,
+         "lstm_dim": 768,
+         "num_lstm_layers": 3,
+         "use_lstm_with_projection": true
+     },
+     "storage": {
+         "sample_from_storage_p": 0.66, // the probability of sampling a batch from the dataset's in-memory storage
+         "storage_size": 15,            // the size of the in-memory storage with respect to a single batch
+         "additive_noise": 1e-5         // add very small gaussian noise to the data in order to increase robustness
+     },
+     "datasets":
+         [
+             {
+                 "name": "vctk_slim",
+                 "path": "../../../audio-datasets/en/VCTK-Corpus/",
+                 "meta_file_train": null,
+                 "meta_file_val": null
+             },
+             {
+                 "name": "libri_tts",
+                 "path": "../../../audio-datasets/en/LibriTTS/train-clean-100",
+                 "meta_file_train": null,
+                 "meta_file_val": null
+             },
+             {
+                 "name": "libri_tts",
+                 "path": "../../../audio-datasets/en/LibriTTS/train-clean-360",
+                 "meta_file_train": null,
+                 "meta_file_val": null
+             },
+             {
+                 "name": "libri_tts",
+                 "path": "../../../audio-datasets/en/LibriTTS/train-other-500",
+                 "meta_file_train": null,
+                 "meta_file_val": null
+             },
+             {
+                 "name": "voxceleb1",
+                 "path": "../../../audio-datasets/en/voxceleb1/",
+                 "meta_file_train": null,
+                 "meta_file_val": null
+             },
+             {
+                 "name": "voxceleb2",
+                 "path": "../../../audio-datasets/en/voxceleb2/",
+                 "meta_file_train": null,
+                 "meta_file_val": null
+             },
+             {
+                 "name": "common_voice",
+                 "path": "../../../audio-datasets/en/MozillaCommonVoice",
+                 "meta_file_train": "train.tsv",
+                 "meta_file_val": "test.tsv"
+             }
+         ]
+ }
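
Note that `win_length` and `hop_length` above are given in samples, not milliseconds. At the configured 16 kHz sample rate they give the standard 25 ms window / 10 ms hop for speaker-verification features, and the 10 ms hop is also what makes `num_frames=160` in `model.py` correspond to the 1.6 s `voice_len` used by the dataset. A quick check of the arithmetic:

```python
# Frame timing implied by the audio settings above.
sample_rate = 16000  # Hz
win_length = 400     # samples
hop_length = 160     # samples

print(1000 * win_length / sample_rate)       # 25.0 ms analysis window
print(1000 * hop_length / sample_rate)       # 10.0 ms hop
print(sample_rate / hop_length)              # 100 mel frames per second
print(int(1.6 * sample_rate / hop_length))   # 160 frames ~= one 1.6 s segment
```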
TTS/speaker_encoder/dataset.py ADDED
@@ -0,0 +1,169 @@
+ import queue
+ import random
+
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+
+
+ class MyDataset(Dataset):
+     def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
+                  storage_size=1, sample_from_storage_p=0.5, additive_noise=0,
+                  num_utter_per_speaker=10, skip_speakers=False, verbose=False):
+         """
+         Args:
+             ap (TTS.utils.audio.AudioProcessor): audio processor object.
+             meta_data (list): list of dataset instances.
+             voice_len (float): voice segment length in seconds.
+             verbose (bool): print diagnostic information.
+         """
+         self.items = meta_data
+         self.sample_rate = ap.sample_rate
+         self.voice_len = voice_len
+         self.seq_len = int(voice_len * self.sample_rate)
+         self.num_speakers_in_batch = num_speakers_in_batch
+         self.num_utter_per_speaker = num_utter_per_speaker
+         self.skip_speakers = skip_speakers
+         self.ap = ap
+         self.verbose = verbose
+         self.__parse_items()
+         self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch)
+         self.sample_from_storage_p = float(sample_from_storage_p)
+         self.additive_noise = float(additive_noise)
+         if self.verbose:
+             print("\n > DataLoader initialization")
+             print(f" | > Speakers per Batch: {num_speakers_in_batch}")
+             print(f" | > Storage Size: {self.storage.maxsize} speakers, each with {num_utter_per_speaker} utters")
+             print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
+             print(f" | > Noise added : {self.additive_noise}")
+             print(f" | > Number of instances : {len(self.items)}")
+             print(f" | > Sequence length: {self.seq_len}")
+             print(f" | > Num speakers: {len(self.speakers)}")
+
+     def load_wav(self, filename):
+         audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
+         return audio
+
+     def load_data(self, idx):
+         text, wav_file, speaker_name = self.items[idx]
+         wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
+         mel = self.ap.melspectrogram(wav).astype("float32")
+         # sample seq_len
+
+         assert len(text) > 0, self.items[idx][1]
+         assert wav.size > 0, self.items[idx][1]
+
+         sample = {
+             "mel": mel,
+             "item_idx": self.items[idx][1],
+             "speaker_name": speaker_name,
+         }
+         return sample
+
+     def __parse_items(self):
+         self.speaker_to_utters = {}
+         for i in self.items:
+             path_ = i[1]
+             speaker_ = i[2]
+             if speaker_ in self.speaker_to_utters.keys():
+                 self.speaker_to_utters[speaker_].append(path_)
+             else:
+                 self.speaker_to_utters[speaker_] = [path_, ]
+
+         if self.skip_speakers:
+             self.speaker_to_utters = {k: v for (k, v) in self.speaker_to_utters.items()
+                                       if len(v) >= self.num_utter_per_speaker}
+
+         self.speakers = [k for (k, v) in self.speaker_to_utters.items()]
+
+     # def __parse_items(self):
+     #     """
+     #     Find unique speaker ids and create a dict mapping utterances from speaker id
+     #     """
+     #     speakers = list({item[-1] for item in self.items})
+     #     self.speaker_to_utters = {}
+     #     self.speakers = []
+     #     for speaker in speakers:
+     #         speaker_utters = [item[1] for item in self.items if item[2] == speaker]
+     #         if len(speaker_utters) < self.num_utter_per_speaker and self.skip_speakers:
+     #             print(
+     #                 f" [!] Skipped speaker {speaker}. Not enough utterances {self.num_utter_per_speaker} vs {len(speaker_utters)}."
+     #             )
+     #         else:
+     #             self.speakers.append(speaker)
+     #             self.speaker_to_utters[speaker] = speaker_utters
+
+     def __len__(self):
+         return int(1e10)
+
+     def __sample_speaker(self):
+         speaker = random.sample(self.speakers, 1)[0]
+         if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
+             utters = random.choices(
+                 self.speaker_to_utters[speaker], k=self.num_utter_per_speaker
+             )
+         else:
+             utters = random.sample(
+                 self.speaker_to_utters[speaker], self.num_utter_per_speaker
+             )
+         return speaker, utters
+
+     def __sample_speaker_utterances(self, speaker):
+         """
+         Sample all M utterances for the given speaker.
+         """
+         wavs = []
+         labels = []
+         for _ in range(self.num_utter_per_speaker):
+             # TODO: dummy but works
+             while True:
+                 if len(self.speaker_to_utters[speaker]) > 0:
+                     utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
+                 else:
+                     self.speakers.remove(speaker)
+                     speaker, _ = self.__sample_speaker()
+                     continue
+                 wav = self.load_wav(utter)
+                 if wav.shape[0] - self.seq_len > 0:
+                     break
+                 self.speaker_to_utters[speaker].remove(utter)
+
+             wavs.append(wav)
+             labels.append(speaker)
+         return wavs, labels
+
+     def __getitem__(self, idx):
+         speaker, _ = self.__sample_speaker()
+         return speaker
+
+     def collate_fn(self, batch):
+         labels = []
+         feats = []
+         for speaker in batch:
+             if random.random() < self.sample_from_storage_p and self.storage.full():
+                 # sample from storage (if full), ignoring the speaker
+                 wavs_, labels_ = random.choice(self.storage.queue)
+             else:
+                 # don't sample from storage, but from disk
+                 wavs_, labels_ = self.__sample_speaker_utterances(speaker)
+                 # if storage is full, remove an item
+                 if self.storage.full():
+                     _ = self.storage.get_nowait()
+                 # put the newly loaded item into storage
+                 self.storage.put_nowait((wavs_, labels_))
+
+             # add random gaussian noise
+             if self.additive_noise > 0:
+                 noises_ = [np.random.normal(0, self.additive_noise, size=len(w)) for w in wavs_]
+                 wavs_ = [wavs_[i] + noises_[i] for i in range(len(wavs_))]
+
+             # take a random seq_len crop of each wav and convert to a mel spectrogram.
+             offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
+             mels_ = [self.ap.melspectrogram(wavs_[i][offsets_[i]: offsets_[i] + self.seq_len]) for i in range(len(wavs_))]
+             feats_ = [torch.FloatTensor(mel) for mel in mels_]
+
+             labels.append(labels_)
+             feats.extend(feats_)
+         feats = torch.stack(feats)
+         return feats.transpose(1, 2), labels
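
Because `__getitem__` only draws a speaker id and all loading, storage sampling, and batching happen in `collate_fn`, the dataset has to be wired to a `DataLoader` with its own collate function. A minimal sketch, assuming `ap` (the audio processor) and `meta_data` (a list of `[text, wav_path, speaker]` items) are built as elsewhere in this repo:

```python
# A sketch of wiring MyDataset into a PyTorch DataLoader.
# `ap` and `meta_data` are assumed to come from the repo's loaders,
# as in TTS/bin/train_encoder.py.
from torch.utils.data import DataLoader

dataset = MyDataset(
    ap,
    meta_data,
    voice_len=1.6,                 # 1.6 s segments
    num_speakers_in_batch=64,      # N speakers per batch
    num_utter_per_speaker=10,      # M utterances per speaker
    skip_speakers=True,            # drop speakers with fewer than M utterances
    storage_size=15,
    sample_from_storage_p=0.66,
    additive_noise=1e-5,
    verbose=True,
)
loader = DataLoader(
    dataset,
    batch_size=64,                  # one "item" per speaker slot
    shuffle=False,                  # sampling is already random inside the dataset
    num_workers=8,
    collate_fn=dataset.collate_fn,  # builds the (N*M, T, num_mels) batch
)
feats, labels = next(iter(loader))  # feats: (N*M, T, num_mels)
```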
TTS/speaker_encoder/losses.py ADDED
@@ -0,0 +1,160 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+
+ # adapted from https://github.com/cvqluu/GE2E-Loss
+ class GE2ELoss(nn.Module):
+     def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"):
+         """
+         Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1]
+         Accepts an input of size (N, M, D)
+             where N is the number of speakers in the batch,
+             M is the number of utterances per speaker,
+             and D is the dimensionality of the embedding vector (e.g. d-vector)
+         Args:
+             - init_w (float): defines the initial value of w in Equation (5) of [1]
+             - init_b (float): defines the initial value of b in Equation (5) of [1]
+         """
+         super(GE2ELoss, self).__init__()
+         # pylint: disable=E1102
+         self.w = nn.Parameter(torch.tensor(init_w))
+         # pylint: disable=E1102
+         self.b = nn.Parameter(torch.tensor(init_b))
+         self.loss_method = loss_method
+
+         print(' > Initialised Generalized End-to-End loss')
+
+         assert self.loss_method in ["softmax", "contrast"]
+
+         if self.loss_method == "softmax":
+             self.embed_loss = self.embed_loss_softmax
+         if self.loss_method == "contrast":
+             self.embed_loss = self.embed_loss_contrast
+
+     # pylint: disable=R0201
+     def calc_new_centroids(self, dvecs, centroids, spkr, utt):
+         """
+         Calculates the new centroids excluding the reference utterance
+         """
+         excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1:]))
+         excl = torch.mean(excl, 0)
+         new_centroids = []
+         for i, centroid in enumerate(centroids):
+             if i == spkr:
+                 new_centroids.append(excl)
+             else:
+                 new_centroids.append(centroid)
+         return torch.stack(new_centroids)
+
+     def calc_cosine_sim(self, dvecs, centroids):
+         """
+         Make the cosine similarity matrix with dims (N, M, N)
+         """
+         cos_sim_matrix = []
+         for spkr_idx, speaker in enumerate(dvecs):
+             cs_row = []
+             for utt_idx, utterance in enumerate(speaker):
+                 new_centroids = self.calc_new_centroids(
+                     dvecs, centroids, spkr_idx, utt_idx
+                 )
+                 # vector based cosine similarity for speed
+                 cs_row.append(
+                     torch.clamp(
+                         torch.mm(
+                             utterance.unsqueeze(1).transpose(0, 1),
+                             new_centroids.transpose(0, 1),
+                         )
+                         / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)),
+                         1e-6,
+                     )
+                 )
+             cs_row = torch.cat(cs_row, dim=0)
+             cos_sim_matrix.append(cs_row)
+         return torch.stack(cos_sim_matrix)
+
+     # pylint: disable=R0201
+     def embed_loss_softmax(self, dvecs, cos_sim_matrix):
+         """
+         Calculates the loss on each embedding $L(e_{ji})$ by taking softmax
+         """
+         N, M, _ = dvecs.shape
+         L = []
+         for j in range(N):
+             L_row = []
+             for i in range(M):
+                 L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j])
+             L_row = torch.stack(L_row)
+             L.append(L_row)
+         return torch.stack(L)
+
+     # pylint: disable=R0201
+     def embed_loss_contrast(self, dvecs, cos_sim_matrix):
+         """
+         Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid
+         """
+         N, M, _ = dvecs.shape
+         L = []
+         for j in range(N):
+             L_row = []
+             for i in range(M):
+                 centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
+                 excl_centroids_sigmoids = torch.cat(
+                     (centroids_sigmoids[:j], centroids_sigmoids[j + 1:])
+                 )
+                 L_row.append(
+                     1.0
+                     - torch.sigmoid(cos_sim_matrix[j, i, j])
+                     + torch.max(excl_centroids_sigmoids)
+                 )
+             L_row = torch.stack(L_row)
+             L.append(L_row)
+         return torch.stack(L)
+
+     def forward(self, dvecs):
+         """
+         Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
+         """
+         centroids = torch.mean(dvecs, 1)
+         cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids)
+         # clamp w in place to keep the similarity scale positive;
+         # a bare `torch.clamp(self.w, 1e-6)` discards its result and is a no-op
+         with torch.no_grad():
+             self.w.clamp_(min=1e-6)
+         cos_sim_matrix = self.w * cos_sim_matrix + self.b
+         L = self.embed_loss(dvecs, cos_sim_matrix)
+         return L.mean()
+
+
+ # adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
+ class AngleProtoLoss(nn.Module):
+     """
+     Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982
+     Accepts an input of size (N, M, D)
+         where N is the number of speakers in the batch,
+         M is the number of utterances per speaker,
+         and D is the dimensionality of the embedding vector
+     Args:
+         - init_w (float): defines the initial value of w
+         - init_b (float): defines the initial value of b
+     """
+     def __init__(self, init_w=10.0, init_b=-5.0):
+         super(AngleProtoLoss, self).__init__()
+         # pylint: disable=E1102
+         self.w = nn.Parameter(torch.tensor(init_w))
+         # pylint: disable=E1102
+         self.b = nn.Parameter(torch.tensor(init_b))
+         self.criterion = torch.nn.CrossEntropyLoss()
+
+         print(' > Initialised Angular Prototypical loss')
+
+     def forward(self, x):
+         """
+         Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
+         """
+         out_anchor = torch.mean(x[:, 1:, :], 1)
+         out_positive = x[:, 0, :]
+         num_speakers = out_anchor.size()[0]
+
+         cos_sim_matrix = F.cosine_similarity(
+             out_positive.unsqueeze(-1).expand(-1, -1, num_speakers),
+             out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2))
+         # clamp w in place to keep the similarity scale positive (see GE2ELoss.forward)
+         with torch.no_grad():
+             self.w.clamp_(min=1e-6)
+         cos_sim_matrix = cos_sim_matrix * self.w + self.b
+         label = torch.from_numpy(np.asarray(range(0, num_speakers))).to(cos_sim_matrix.device)
+         L = self.criterion(cos_sim_matrix, label)
+         return L
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LSTMWithProjection(nn.Module):
6
+ def __init__(self, input_size, hidden_size, proj_size):
7
+ super().__init__()
8
+ self.input_size = input_size
9
+ self.hidden_size = hidden_size
10
+ self.proj_size = proj_size
11
+ self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
12
+ self.linear = nn.Linear(hidden_size, proj_size, bias=False)
13
+
14
+ def forward(self, x):
15
+ self.lstm.flatten_parameters()
16
+ o, (_, _) = self.lstm(x)
17
+ return self.linear(o)
18
+
19
+ class LSTMWithoutProjection(nn.Module):
20
+ def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
21
+ super().__init__()
22
+ self.lstm = nn.LSTM(input_size=input_dim,
23
+ hidden_size=lstm_dim,
24
+ num_layers=num_lstm_layers,
25
+ batch_first=True)
26
+ self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
27
+ self.relu = nn.ReLU()
28
+ def forward(self, x):
29
+ _, (hidden, _) = self.lstm(x)
30
+ return self.relu(self.linear(hidden[-1]))
31
+
32
+ class SpeakerEncoder(nn.Module):
33
+ def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
34
+ super().__init__()
35
+ self.use_lstm_with_projection = use_lstm_with_projection
36
+ layers = []
37
+ # choise LSTM layer
38
+ if use_lstm_with_projection:
39
+ layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim))
40
+ for _ in range(num_lstm_layers - 1):
41
+ layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim))
42
+ self.layers = nn.Sequential(*layers)
43
+ else:
44
+ self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers)
45
+
46
+ self._init_layers()
47
+
48
+ def _init_layers(self):
49
+ for name, param in self.layers.named_parameters():
50
+ if "bias" in name:
51
+ nn.init.constant_(param, 0.0)
52
+ elif "weight" in name:
53
+ nn.init.xavier_normal_(param)
54
+
55
+ def forward(self, x):
56
+ # TODO: implement state passing for lstms
57
+ d = self.layers(x)
58
+ if self.use_lstm_with_projection:
59
+ d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
60
+ else:
61
+ d = torch.nn.functional.normalize(d, p=2, dim=1)
62
+ return d
63
+
64
+ @torch.no_grad()
65
+ def inference(self, x):
66
+ d = self.layers.forward(x)
67
+ if self.use_lstm_with_projection:
68
+ d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1)
69
+ else:
70
+ d = torch.nn.functional.normalize(d, p=2, dim=1)
71
+ return d
72
+
73
+ def compute_embedding(self, x, num_frames=160, overlap=0.5):
74
+ """
75
+ Generate embeddings for a batch of utterances
76
+ x: 1xTxD
77
+ """
78
+ num_overlap = int(num_frames * overlap)
79
+ max_len = x.shape[1]
80
+ embed = None
81
+ cur_iter = 0
82
+ for offset in range(0, max_len, num_frames - num_overlap):
83
+ cur_iter += 1
84
+ end_offset = min(x.shape[1], offset + num_frames)
85
+ frames = x[:, offset:end_offset]
86
+ if embed is None:
87
+ embed = self.inference(frames)
88
+ else:
89
+ embed += self.inference(frames)
90
+ return embed / cur_iter
91
+
92
+ def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5):
93
+ """
94
+ Generate embeddings for a batch of utterances
95
+ x: BxTxD
96
+ """
97
+ num_overlap = num_frames * overlap
98
+ max_len = x.shape[1]
99
+ embed = None
100
+ num_iters = seq_lens / (num_frames - num_overlap)
101
+ cur_iter = 0
102
+ for offset in range(0, max_len, num_frames - num_overlap):
103
+ cur_iter += 1
104
+ end_offset = min(x.shape[1], offset + num_frames)
105
+ frames = x[:, offset:end_offset]
106
+ if embed is None:
107
+ embed = self.inference(frames)
108
+ else:
109
+ embed[cur_iter <= num_iters, :] += self.inference(
110
+ frames[cur_iter <= num_iters, :, :]
111
+ )
112
+ return embed / num_iters
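
`compute_embedding` averages the d-vectors of 50 %-overlapping windows over the utterance; with the defaults (`num_frames=160`, `overlap=0.5`) that is a 1.6 s window advanced by 0.8 s at the config's 10 ms hop. A quick sketch of the window layout it produces (the 5 s utterance length here is just an illustrative value):

```python
# Window offsets as produced by SpeakerEncoder.compute_embedding defaults.
num_frames, overlap = 160, 0.5
step = num_frames - int(num_frames * overlap)  # 80 frames = 0.8 s at a 10 ms hop

max_len = 500  # e.g. a 5 s utterance -> 500 mel frames
windows = [(off, min(max_len, off + num_frames))
           for off in range(0, max_len, step)]
print(windows[:3])   # [(0, 160), (80, 240), (160, 320)]
print(len(windows))  # 7 windows are embedded and averaged
# note: tail windows shorter than num_frames are still embedded as-is
```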
TTS/speaker_encoder/requirements.txt ADDED
@@ -0,0 +1,2 @@
+ umap-learn
+ numpy>=1.17.0
TTS/speaker_encoder/umap.png ADDED