Spaces:
Running
on
Zero
Running
on
Zero
v0.5.0
Browse files- .gitignore +214 -0
- LICENSE +201 -0
- SoniTranslate_Colab.ipynb +124 -0
- app.py +2 -0
- app_rvc.py +0 -0
- assets/logo.jpeg +0 -0
- docs/windows_install.md +150 -0
- lib/audio.py +21 -0
- lib/infer_pack/attentions.py +417 -0
- lib/infer_pack/commons.py +166 -0
- lib/infer_pack/models.py +1142 -0
- lib/infer_pack/modules.py +522 -0
- lib/infer_pack/transforms.py +209 -0
- lib/rmvpe.py +422 -0
- mdx_models/data.json +354 -0
- packages.txt +3 -0
- pre-requirements.txt +15 -0
- requirements.txt +19 -0
- requirements_xtts.txt +58 -0
- soni_translate/audio_segments.py +141 -0
- soni_translate/language_configuration.py +551 -0
- soni_translate/languages_gui.py +0 -0
- soni_translate/logging_setup.py +68 -0
- soni_translate/mdx_net.py +582 -0
- soni_translate/postprocessor.py +229 -0
- soni_translate/preprocessor.py +308 -0
- soni_translate/speech_segmentation.py +499 -0
- soni_translate/text_multiformat_processor.py +987 -0
- soni_translate/text_to_speech.py +1574 -0
- soni_translate/translate_segments.py +457 -0
- soni_translate/utils.py +487 -0
- vci_pipeline.py +454 -0
- voice_main.py +732 -0
.gitignore
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib64/
|
18 |
+
parts/
|
19 |
+
sdist/
|
20 |
+
var/
|
21 |
+
wheels/
|
22 |
+
share/python-wheels/
|
23 |
+
*.egg-info/
|
24 |
+
.installed.cfg
|
25 |
+
*.egg
|
26 |
+
MANIFEST
|
27 |
+
|
28 |
+
# PyInstaller
|
29 |
+
# Usually these files are written by a python script from a template
|
30 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
31 |
+
*.manifest
|
32 |
+
*.spec
|
33 |
+
|
34 |
+
# Installer logs
|
35 |
+
pip-log.txt
|
36 |
+
pip-delete-this-directory.txt
|
37 |
+
|
38 |
+
# Unit test / coverage reports
|
39 |
+
htmlcov/
|
40 |
+
.tox/
|
41 |
+
.nox/
|
42 |
+
.coverage
|
43 |
+
.coverage.*
|
44 |
+
.cache
|
45 |
+
nosetests.xml
|
46 |
+
coverage.xml
|
47 |
+
*.cover
|
48 |
+
*.py,cover
|
49 |
+
.hypothesis/
|
50 |
+
.pytest_cache/
|
51 |
+
cover/
|
52 |
+
|
53 |
+
# Translations
|
54 |
+
*.mo
|
55 |
+
*.pot
|
56 |
+
|
57 |
+
# Django stuff:
|
58 |
+
*.log
|
59 |
+
local_settings.py
|
60 |
+
db.sqlite3
|
61 |
+
db.sqlite3-journal
|
62 |
+
|
63 |
+
# Flask stuff:
|
64 |
+
instance/
|
65 |
+
.webassets-cache
|
66 |
+
|
67 |
+
# Scrapy stuff:
|
68 |
+
.scrapy
|
69 |
+
|
70 |
+
# Sphinx documentation
|
71 |
+
docs/_build/
|
72 |
+
|
73 |
+
# PyBuilder
|
74 |
+
.pybuilder/
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
# For a library or package, you might want to ignore these files since the code is
|
86 |
+
# intended to run in multiple environments; otherwise, check them in:
|
87 |
+
# .python-version
|
88 |
+
|
89 |
+
# pipenv
|
90 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
91 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
92 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
93 |
+
# install all needed dependencies.
|
94 |
+
#Pipfile.lock
|
95 |
+
|
96 |
+
# poetry
|
97 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
98 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
99 |
+
# commonly ignored for libraries.
|
100 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
101 |
+
#poetry.lock
|
102 |
+
|
103 |
+
# pdm
|
104 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
105 |
+
#pdm.lock
|
106 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
107 |
+
# in version control.
|
108 |
+
# https://pdm.fming.dev/#use-with-ide
|
109 |
+
.pdm.toml
|
110 |
+
|
111 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
112 |
+
__pypackages__/
|
113 |
+
|
114 |
+
# Celery stuff
|
115 |
+
celerybeat-schedule
|
116 |
+
celerybeat.pid
|
117 |
+
|
118 |
+
# SageMath parsed files
|
119 |
+
*.sage.py
|
120 |
+
|
121 |
+
# Environments
|
122 |
+
.env
|
123 |
+
.venv
|
124 |
+
env/
|
125 |
+
venv/
|
126 |
+
ENV/
|
127 |
+
env.bak/
|
128 |
+
venv.bak/
|
129 |
+
|
130 |
+
# Spyder project settings
|
131 |
+
.spyderproject
|
132 |
+
.spyproject
|
133 |
+
|
134 |
+
# Rope project settings
|
135 |
+
.ropeproject
|
136 |
+
|
137 |
+
# mkdocs documentation
|
138 |
+
/site
|
139 |
+
|
140 |
+
# mypy
|
141 |
+
.mypy_cache/
|
142 |
+
.dmypy.json
|
143 |
+
dmypy.json
|
144 |
+
|
145 |
+
# Pyre type checker
|
146 |
+
.pyre/
|
147 |
+
|
148 |
+
# pytype static type analyzer
|
149 |
+
.pytype/
|
150 |
+
|
151 |
+
# Cython debug symbols
|
152 |
+
cython_debug/
|
153 |
+
|
154 |
+
# PyCharm
|
155 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
156 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
157 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
158 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
159 |
+
#.idea/
|
160 |
+
|
161 |
+
# Ignore
|
162 |
+
sub_tra.*
|
163 |
+
sub_ori.*
|
164 |
+
SPEAKER_00.*
|
165 |
+
SPEAKER_01.*
|
166 |
+
SPEAKER_02.*
|
167 |
+
SPEAKER_03.*
|
168 |
+
SPEAKER_04.*
|
169 |
+
SPEAKER_05.*
|
170 |
+
SPEAKER_06.*
|
171 |
+
SPEAKER_07.*
|
172 |
+
SPEAKER_08.*
|
173 |
+
SPEAKER_09.*
|
174 |
+
SPEAKER_10.*
|
175 |
+
SPEAKER_11.*
|
176 |
+
task_subtitle.*
|
177 |
+
*.mp3
|
178 |
+
*.mp4
|
179 |
+
*.ogg
|
180 |
+
*.wav
|
181 |
+
*.mkv
|
182 |
+
*.webm
|
183 |
+
*.avi
|
184 |
+
*.mpg
|
185 |
+
*.mov
|
186 |
+
*.ogv
|
187 |
+
*.wmv
|
188 |
+
test.py
|
189 |
+
list.txt
|
190 |
+
text_preprocessor.txt
|
191 |
+
text_translation.txt
|
192 |
+
*.srt
|
193 |
+
*.vtt
|
194 |
+
*.tsv
|
195 |
+
*.aud
|
196 |
+
*.ass
|
197 |
+
*.pt
|
198 |
+
.vscode/
|
199 |
+
mdx_models/*.onnx
|
200 |
+
_XTTS_/
|
201 |
+
downloads/
|
202 |
+
logs/
|
203 |
+
weights/
|
204 |
+
clean_song_output/
|
205 |
+
audio2/
|
206 |
+
audio/
|
207 |
+
outputs/
|
208 |
+
processed/
|
209 |
+
OPENVOICE_MODELS/
|
210 |
+
PIPER_MODELS/
|
211 |
+
WHISPER_MODELS/
|
212 |
+
whisper_api_audio_parts/
|
213 |
+
uroman/
|
214 |
+
pdf_images/
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
SoniTranslate_Colab.ipynb
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"provenance": [],
|
7 |
+
"gpuType": "T4",
|
8 |
+
"include_colab_link": true
|
9 |
+
},
|
10 |
+
"kernelspec": {
|
11 |
+
"name": "python3",
|
12 |
+
"display_name": "Python 3"
|
13 |
+
},
|
14 |
+
"language_info": {
|
15 |
+
"name": "python"
|
16 |
+
},
|
17 |
+
"accelerator": "GPU"
|
18 |
+
},
|
19 |
+
"cells": [
|
20 |
+
{
|
21 |
+
"cell_type": "markdown",
|
22 |
+
"metadata": {
|
23 |
+
"id": "view-in-github",
|
24 |
+
"colab_type": "text"
|
25 |
+
},
|
26 |
+
"source": [
|
27 |
+
"<a href=\"https://colab.research.google.com/github/R3gm/SoniTranslate/blob/main/SoniTranslate_Colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
28 |
+
]
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "markdown",
|
32 |
+
"source": [
|
33 |
+
"# SoniTranslate\n",
|
34 |
+
"\n",
|
35 |
+
"| Description | Link |\n",
|
36 |
+
"| ----------- | ---- |\n",
|
37 |
+
"| 🎉 Repository | [![GitHub Repository](https://img.shields.io/badge/GitHub-Repository-black?style=flat-square&logo=github)](https://github.com/R3gm/SoniTranslate/) |\n",
|
38 |
+
"| 🚀 Online Demo in HF | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/r3gm/SoniTranslate_translate_audio_of_a_video_content) |\n",
|
39 |
+
"\n",
|
40 |
+
"\n"
|
41 |
+
],
|
42 |
+
"metadata": {
|
43 |
+
"id": "8lw0EgLex-YZ"
|
44 |
+
}
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"cell_type": "code",
|
48 |
+
"execution_count": null,
|
49 |
+
"metadata": {
|
50 |
+
"id": "LUgwm0rfx0_J",
|
51 |
+
"cellView": "form"
|
52 |
+
},
|
53 |
+
"outputs": [],
|
54 |
+
"source": [
|
55 |
+
"# @title Install requirements for SoniTranslate\n",
|
56 |
+
"!git clone https://github.com/r3gm/SoniTranslate.git\n",
|
57 |
+
"%cd SoniTranslate\n",
|
58 |
+
"\n",
|
59 |
+
"!apt install git-lfs\n",
|
60 |
+
"!git lfs install\n",
|
61 |
+
"\n",
|
62 |
+
"!sed -i 's|git+https://github.com/R3gm/whisperX.git@cuda_11_8|git+https://github.com/R3gm/whisperX.git@cuda_12_x|' requirements_base.txt\n",
|
63 |
+
"!pip install -q -r requirements_base.txt\n",
|
64 |
+
"!pip install -q -r requirements_extra.txt\n",
|
65 |
+
"!pip install -q ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/\n",
|
66 |
+
"\n",
|
67 |
+
"Install_PIPER_TTS = True # @param {type:\"boolean\"}\n",
|
68 |
+
"\n",
|
69 |
+
"if Install_PIPER_TTS:\n",
|
70 |
+
" !pip install -q piper-tts==1.2.0\n",
|
71 |
+
"\n",
|
72 |
+
"Install_Coqui_XTTS = True # @param {type:\"boolean\"}\n",
|
73 |
+
"\n",
|
74 |
+
"if Install_Coqui_XTTS:\n",
|
75 |
+
" !pip install -q -r requirements_xtts.txt\n",
|
76 |
+
" !pip install -q TTS==0.21.1 --no-deps"
|
77 |
+
]
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"cell_type": "markdown",
|
81 |
+
"source": [
|
82 |
+
"One important step is to accept the license agreement for using Pyannote. You need to have an account on Hugging Face and `accept the license to use the models`: https://huggingface.co/pyannote/speaker-diarization and https://huggingface.co/pyannote/segmentation\n",
|
83 |
+
"\n",
|
84 |
+
"\n",
|
85 |
+
"\n",
|
86 |
+
"\n",
|
87 |
+
"Get your KEY TOKEN here: https://hf.co/settings/tokens"
|
88 |
+
],
|
89 |
+
"metadata": {
|
90 |
+
"id": "LTaTstXPXNg2"
|
91 |
+
}
|
92 |
+
},
|
93 |
+
{
|
94 |
+
"cell_type": "code",
|
95 |
+
"source": [
|
96 |
+
"#@markdown # `RUN THE WEB APP`\n",
|
97 |
+
"YOUR_HF_TOKEN = \"\" #@param {type:'string'}\n",
|
98 |
+
"%env YOUR_HF_TOKEN={YOUR_HF_TOKEN}\n",
|
99 |
+
"theme = \"Taithrah/Minimal\" # @param [\"Taithrah/Minimal\", \"aliabid94/new-theme\", \"gstaff/xkcd\", \"ParityError/LimeFace\", \"abidlabs/pakistan\", \"rottenlittlecreature/Moon_Goblin\", \"ysharma/llamas\", \"gradio/dracula_revamped\"]\n",
|
100 |
+
"interface_language = \"english\" # @param ['arabic', 'azerbaijani', 'chinese_zh_cn', 'english', 'french', 'german', 'hindi', 'indonesian', 'italian', 'japanese', 'korean', 'marathi', 'polish', 'portuguese', 'russian', 'spanish', 'swedish', 'turkish', 'ukrainian', 'vietnamese']\n",
|
101 |
+
"verbosity_level = \"info\" # @param [\"debug\", \"info\", \"warning\", \"error\", \"critical\"]\n",
|
102 |
+
"\n",
|
103 |
+
"\n",
|
104 |
+
"%cd /content/SoniTranslate\n",
|
105 |
+
"!python app_rvc.py --theme {theme} --verbosity_level {verbosity_level} --language {interface_language} --public_url"
|
106 |
+
],
|
107 |
+
"metadata": {
|
108 |
+
"id": "XkhXfaFw4R4J",
|
109 |
+
"cellView": "form"
|
110 |
+
},
|
111 |
+
"execution_count": null,
|
112 |
+
"outputs": []
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"cell_type": "markdown",
|
116 |
+
"source": [
|
117 |
+
"Open the `public URL` when it appears"
|
118 |
+
],
|
119 |
+
"metadata": {
|
120 |
+
"id": "KJW3KrhZJh0u"
|
121 |
+
}
|
122 |
+
}
|
123 |
+
]
|
124 |
+
}
|
app.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.system("python app_rvc.py --language french --theme aliabid94/new-theme")
|
app_rvc.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
assets/logo.jpeg
ADDED
docs/windows_install.md
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Install Locally Windows
|
2 |
+
|
3 |
+
### Before You Start
|
4 |
+
|
5 |
+
Before you start installing and using SoniTranslate, there are a few things you need to do:
|
6 |
+
|
7 |
+
1. Install Microsoft Visual C++ Build Tools, MSVC and Windows 10 SDK:
|
8 |
+
|
9 |
+
* Go to the [Visual Studio downloads page](https://visualstudio.microsoft.com/visual-cpp-build-tools/); Or maybe you already have **Visual Studio Installer**? Open it. If you have it already click modify.
|
10 |
+
* Download and install the "Build Tools for Visual Studio" if you don't have it.
|
11 |
+
* During installation, under "Workloads", select "C++ build tools" and ensure the latest versions of "MSVCv142 - VS 2019 C++ x64/x86 build tools" and "Windows 10 SDK" are selected ("Windows 11 SDK" if you are using Windows 11); OR go to individual components and find those two listed.
|
12 |
+
* Complete the installation.
|
13 |
+
|
14 |
+
2. Verify the NVIDIA driver on Windows using the command line:
|
15 |
+
|
16 |
+
* **Open Command Prompt:** Press `Win + R`, type `cmd`, then press `Enter`.
|
17 |
+
|
18 |
+
* **Type the command:** `nvidia-smi` and press `Enter`.
|
19 |
+
|
20 |
+
* **Look for "CUDA Version"** in the output.
|
21 |
+
|
22 |
+
```
|
23 |
+
+-----------------------------------------------------------------------------+
|
24 |
+
| NVIDIA-SMI 522.25 Driver Version: 522.25 CUDA Version: 11.8 |
|
25 |
+
|-------------------------------+----------------------+----------------------+
|
26 |
+
```
|
27 |
+
|
28 |
+
3. If you see that your CUDA version is less than 11.8, you should update your NVIDIA driver. Visit the NVIDIA website's driver download page (https://www.nvidia.com/Download/index.aspx) and enter your graphics card information.
|
29 |
+
|
30 |
+
4. Accept the license agreement for using Pyannote. You need to have an account on Hugging Face and `accept the license to use the models`: https://huggingface.co/pyannote/speaker-diarization and https://huggingface.co/pyannote/segmentation
|
31 |
+
5. Create a [huggingface token](https://huggingface.co/settings/tokens). Hugging Face is a natural language processing platform that provides access to state-of-the-art models and tools. You will need to create a token in order to use some of the automatic model download features in SoniTranslate. Follow the instructions on the Hugging Face website to create a token.
|
32 |
+
6. Install [Anaconda](https://www.anaconda.com/) or [Miniconda](https://docs.anaconda.com/free/miniconda/miniconda-install/). Anaconda is a free and open-source distribution of Python and R. It includes a package manager called conda that makes it easy to install and manage Python environments and packages. Follow the instructions on the Anaconda website to download and install Anaconda on your system.
|
33 |
+
7. Install Git for your system. Git is a version control system that helps you track changes to your code and collaborate with other developers. You can install Git with Anaconda by running `conda install -c anaconda git -y` in your terminal (Do this after step 1 in the following section.). If you have trouble installing Git via Anaconda, you can use the following link instead:
|
34 |
+
- [Git for Windows](https://git-scm.com/download/win)
|
35 |
+
|
36 |
+
Once you have completed these steps, you will be ready to install SoniTranslate.
|
37 |
+
|
38 |
+
### Getting Started
|
39 |
+
|
40 |
+
To install SoniTranslate, follow these steps:
|
41 |
+
|
42 |
+
1. Create a suitable anaconda environment for SoniTranslate and activate it:
|
43 |
+
|
44 |
+
```
|
45 |
+
conda create -n sonitr python=3.10 -y
|
46 |
+
conda activate sonitr
|
47 |
+
```
|
48 |
+
|
49 |
+
2. Clone this github repository and navigate to it:
|
50 |
+
```
|
51 |
+
git clone https://github.com/r3gm/SoniTranslate.git
|
52 |
+
cd SoniTranslate
|
53 |
+
```
|
54 |
+
3. Install CUDA Toolkit 11.8.0
|
55 |
+
|
56 |
+
```
|
57 |
+
conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit -y
|
58 |
+
```
|
59 |
+
|
60 |
+
4. Install PyTorch using conda
|
61 |
+
```
|
62 |
+
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
|
63 |
+
```
|
64 |
+
|
65 |
+
5. Install required packages:
|
66 |
+
|
67 |
+
```
|
68 |
+
pip install -r requirements_base.txt -v
|
69 |
+
pip install -r requirements_extra.txt -v
|
70 |
+
pip install onnxruntime-gpu
|
71 |
+
```
|
72 |
+
|
73 |
+
6. Install [ffmpeg](https://ffmpeg.org/download.html). FFmpeg is a free software project that produces libraries and programs for handling multimedia data. You will need it to process audio and video files. You can install ffmpeg with Anaconda by running `conda install -y ffmpeg` in your terminal (recommended). If you have trouble installing ffmpeg via Anaconda, you can use the following link instead: (https://ffmpeg.org/ffmpeg.html). Once it is installed, make sure it is in your PATH by running `ffmpeg -h` in your terminal. If you don't get an error message, you're good to go.
|
74 |
+
|
75 |
+
7. Optional install:
|
76 |
+
|
77 |
+
After installing FFmpeg, you can install these optional packages.
|
78 |
+
|
79 |
+
[Coqui XTTS](https://github.com/coqui-ai/TTS) is a text-to-speech (TTS) model that lets you generate realistic voices in different languages. It can clone voices with just a short audio clip, even speak in a different language! It's like having a personal voice mimic for any text you need spoken.
|
80 |
+
|
81 |
+
```
|
82 |
+
pip install -q -r requirements_xtts.txt
|
83 |
+
pip install -q TTS==0.21.1 --no-deps
|
84 |
+
```
|
85 |
+
|
86 |
+
[Piper TTS](https://github.com/rhasspy/piper) is a fast, local neural text to speech system that sounds great and is optimized for the Raspberry Pi 4. Piper is used in a variety of projects. Voices are trained with VITS and exported to the onnxruntime.
|
87 |
+
|
88 |
+
🚧 For Windows users, it's important to note that the Python module piper-tts is not fully supported on this operating system. While it works smoothly on Linux, Windows compatibility is currently experimental. If you still wish to install it on Windows, you can follow this experimental method:
|
89 |
+
|
90 |
+
```
|
91 |
+
pip install https://github.com/R3gm/piper-phonemize/releases/download/1.2.0/piper_phonemize-1.2.0-cp310-cp310-win_amd64.whl
|
92 |
+
pip install sherpa-onnx==1.9.12
|
93 |
+
pip install piper-tts==1.2.0 --no-deps
|
94 |
+
```
|
95 |
+
|
96 |
+
8. Setting your [Hugging Face token](https://huggingface.co/settings/tokens) as an environment variable in quotes:
|
97 |
+
|
98 |
+
```
|
99 |
+
conda env config vars set YOUR_HF_TOKEN="YOUR_HUGGING_FACE_TOKEN_HERE"
|
100 |
+
conda deactivate
|
101 |
+
```
|
102 |
+
|
103 |
+
|
104 |
+
### Running SoniTranslate
|
105 |
+
|
106 |
+
To run SoniTranslate locally, make sure the `sonitr` conda environment is active:
|
107 |
+
|
108 |
+
```
|
109 |
+
conda activate sonitr
|
110 |
+
```
|
111 |
+
|
112 |
+
Then navigate to the `SoniTranslate` folder and run either the `app_rvc.py`
|
113 |
+
|
114 |
+
```
|
115 |
+
python app_rvc.py
|
116 |
+
```
|
117 |
+
When the `local URL` `http://127.0.0.1:7860` is displayed in the terminal, simply open this URL in your web browser to access the SoniTranslate interface.
|
118 |
+
|
119 |
+
### Stop and close SoniTranslate.
|
120 |
+
|
121 |
+
In most environments, you can stop the execution by pressing Ctrl+C in the terminal where you launched the script `app_rvc.py`. This will interrupt the program and stop the Gradio app.
|
122 |
+
To deactivate the Conda environment, you can use the following command:
|
123 |
+
|
124 |
+
```
|
125 |
+
conda deactivate
|
126 |
+
```
|
127 |
+
|
128 |
+
This will deactivate the currently active Conda environment sonitr, and you'll return to the base environment or the global Python environment.
|
129 |
+
|
130 |
+
### Starting Over
|
131 |
+
|
132 |
+
If you need to start over from scratch, you can delete the `SoniTranslate` folder and remove the `sonitr` conda environment with the following set of commands:
|
133 |
+
|
134 |
+
```
|
135 |
+
conda deactivate
|
136 |
+
conda env remove -n sonitr
|
137 |
+
```
|
138 |
+
|
139 |
+
With the `sonitr` environment removed, you can start over with a fresh installation.
|
140 |
+
|
141 |
+
### Notes
|
142 |
+
- To use OpenAI's GPT API for translation, set up your OpenAI API key as an environment variable in quotes:
|
143 |
+
|
144 |
+
```
|
145 |
+
conda activate sonitr
|
146 |
+
conda env config vars set OPENAI_API_KEY="your-api-key-here"
|
147 |
+
conda deactivate
|
148 |
+
```
|
149 |
+
|
150 |
+
- Alternatively, you can install the CUDA Toolkit 11.8.0 directly on your system [CUDA Toolkit 11.8.0](https://developer.nvidia.com/cuda-11-8-0-download-archive).
|
lib/audio.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ffmpeg
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def load_audio(file, sr):
|
6 |
+
try:
|
7 |
+
# https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
|
8 |
+
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
9 |
+
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
10 |
+
file = (
|
11 |
+
file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
|
12 |
+
) # To prevent beginners from copying paths with leading or trailing spaces, quotation marks, and line breaks.
|
13 |
+
out, _ = (
|
14 |
+
ffmpeg.input(file, threads=0)
|
15 |
+
.output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
|
16 |
+
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
|
17 |
+
)
|
18 |
+
except Exception as e:
|
19 |
+
raise RuntimeError(f"Failed to load audio: {e}")
|
20 |
+
|
21 |
+
return np.frombuffer(out, np.float32).flatten()
|
lib/infer_pack/attentions.py
ADDED
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from lib.infer_pack import commons
|
9 |
+
from lib.infer_pack import modules
|
10 |
+
from lib.infer_pack.modules import LayerNorm
|
11 |
+
|
12 |
+
|
13 |
+
class Encoder(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
hidden_channels,
|
17 |
+
filter_channels,
|
18 |
+
n_heads,
|
19 |
+
n_layers,
|
20 |
+
kernel_size=1,
|
21 |
+
p_dropout=0.0,
|
22 |
+
window_size=10,
|
23 |
+
**kwargs
|
24 |
+
):
|
25 |
+
super().__init__()
|
26 |
+
self.hidden_channels = hidden_channels
|
27 |
+
self.filter_channels = filter_channels
|
28 |
+
self.n_heads = n_heads
|
29 |
+
self.n_layers = n_layers
|
30 |
+
self.kernel_size = kernel_size
|
31 |
+
self.p_dropout = p_dropout
|
32 |
+
self.window_size = window_size
|
33 |
+
|
34 |
+
self.drop = nn.Dropout(p_dropout)
|
35 |
+
self.attn_layers = nn.ModuleList()
|
36 |
+
self.norm_layers_1 = nn.ModuleList()
|
37 |
+
self.ffn_layers = nn.ModuleList()
|
38 |
+
self.norm_layers_2 = nn.ModuleList()
|
39 |
+
for i in range(self.n_layers):
|
40 |
+
self.attn_layers.append(
|
41 |
+
MultiHeadAttention(
|
42 |
+
hidden_channels,
|
43 |
+
hidden_channels,
|
44 |
+
n_heads,
|
45 |
+
p_dropout=p_dropout,
|
46 |
+
window_size=window_size,
|
47 |
+
)
|
48 |
+
)
|
49 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
50 |
+
self.ffn_layers.append(
|
51 |
+
FFN(
|
52 |
+
hidden_channels,
|
53 |
+
hidden_channels,
|
54 |
+
filter_channels,
|
55 |
+
kernel_size,
|
56 |
+
p_dropout=p_dropout,
|
57 |
+
)
|
58 |
+
)
|
59 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
60 |
+
|
61 |
+
def forward(self, x, x_mask):
|
62 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
63 |
+
x = x * x_mask
|
64 |
+
for i in range(self.n_layers):
|
65 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
66 |
+
y = self.drop(y)
|
67 |
+
x = self.norm_layers_1[i](x + y)
|
68 |
+
|
69 |
+
y = self.ffn_layers[i](x, x_mask)
|
70 |
+
y = self.drop(y)
|
71 |
+
x = self.norm_layers_2[i](x + y)
|
72 |
+
x = x * x_mask
|
73 |
+
return x
|
74 |
+
|
75 |
+
|
76 |
+
class Decoder(nn.Module):
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
hidden_channels,
|
80 |
+
filter_channels,
|
81 |
+
n_heads,
|
82 |
+
n_layers,
|
83 |
+
kernel_size=1,
|
84 |
+
p_dropout=0.0,
|
85 |
+
proximal_bias=False,
|
86 |
+
proximal_init=True,
|
87 |
+
**kwargs
|
88 |
+
):
|
89 |
+
super().__init__()
|
90 |
+
self.hidden_channels = hidden_channels
|
91 |
+
self.filter_channels = filter_channels
|
92 |
+
self.n_heads = n_heads
|
93 |
+
self.n_layers = n_layers
|
94 |
+
self.kernel_size = kernel_size
|
95 |
+
self.p_dropout = p_dropout
|
96 |
+
self.proximal_bias = proximal_bias
|
97 |
+
self.proximal_init = proximal_init
|
98 |
+
|
99 |
+
self.drop = nn.Dropout(p_dropout)
|
100 |
+
self.self_attn_layers = nn.ModuleList()
|
101 |
+
self.norm_layers_0 = nn.ModuleList()
|
102 |
+
self.encdec_attn_layers = nn.ModuleList()
|
103 |
+
self.norm_layers_1 = nn.ModuleList()
|
104 |
+
self.ffn_layers = nn.ModuleList()
|
105 |
+
self.norm_layers_2 = nn.ModuleList()
|
106 |
+
for i in range(self.n_layers):
|
107 |
+
self.self_attn_layers.append(
|
108 |
+
MultiHeadAttention(
|
109 |
+
hidden_channels,
|
110 |
+
hidden_channels,
|
111 |
+
n_heads,
|
112 |
+
p_dropout=p_dropout,
|
113 |
+
proximal_bias=proximal_bias,
|
114 |
+
proximal_init=proximal_init,
|
115 |
+
)
|
116 |
+
)
|
117 |
+
self.norm_layers_0.append(LayerNorm(hidden_channels))
|
118 |
+
self.encdec_attn_layers.append(
|
119 |
+
MultiHeadAttention(
|
120 |
+
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
|
121 |
+
)
|
122 |
+
)
|
123 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
124 |
+
self.ffn_layers.append(
|
125 |
+
FFN(
|
126 |
+
hidden_channels,
|
127 |
+
hidden_channels,
|
128 |
+
filter_channels,
|
129 |
+
kernel_size,
|
130 |
+
p_dropout=p_dropout,
|
131 |
+
causal=True,
|
132 |
+
)
|
133 |
+
)
|
134 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
135 |
+
|
136 |
+
def forward(self, x, x_mask, h, h_mask):
|
137 |
+
"""
|
138 |
+
x: decoder input
|
139 |
+
h: encoder output
|
140 |
+
"""
|
141 |
+
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
|
142 |
+
device=x.device, dtype=x.dtype
|
143 |
+
)
|
144 |
+
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
145 |
+
x = x * x_mask
|
146 |
+
for i in range(self.n_layers):
|
147 |
+
y = self.self_attn_layers[i](x, x, self_attn_mask)
|
148 |
+
y = self.drop(y)
|
149 |
+
x = self.norm_layers_0[i](x + y)
|
150 |
+
|
151 |
+
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
|
152 |
+
y = self.drop(y)
|
153 |
+
x = self.norm_layers_1[i](x + y)
|
154 |
+
|
155 |
+
y = self.ffn_layers[i](x, x_mask)
|
156 |
+
y = self.drop(y)
|
157 |
+
x = self.norm_layers_2[i](x + y)
|
158 |
+
x = x * x_mask
|
159 |
+
return x
|
160 |
+
|
161 |
+
|
162 |
+
class MultiHeadAttention(nn.Module):
|
163 |
+
def __init__(
|
164 |
+
self,
|
165 |
+
channels,
|
166 |
+
out_channels,
|
167 |
+
n_heads,
|
168 |
+
p_dropout=0.0,
|
169 |
+
window_size=None,
|
170 |
+
heads_share=True,
|
171 |
+
block_length=None,
|
172 |
+
proximal_bias=False,
|
173 |
+
proximal_init=False,
|
174 |
+
):
|
175 |
+
super().__init__()
|
176 |
+
assert channels % n_heads == 0
|
177 |
+
|
178 |
+
self.channels = channels
|
179 |
+
self.out_channels = out_channels
|
180 |
+
self.n_heads = n_heads
|
181 |
+
self.p_dropout = p_dropout
|
182 |
+
self.window_size = window_size
|
183 |
+
self.heads_share = heads_share
|
184 |
+
self.block_length = block_length
|
185 |
+
self.proximal_bias = proximal_bias
|
186 |
+
self.proximal_init = proximal_init
|
187 |
+
self.attn = None
|
188 |
+
|
189 |
+
self.k_channels = channels // n_heads
|
190 |
+
self.conv_q = nn.Conv1d(channels, channels, 1)
|
191 |
+
self.conv_k = nn.Conv1d(channels, channels, 1)
|
192 |
+
self.conv_v = nn.Conv1d(channels, channels, 1)
|
193 |
+
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
194 |
+
self.drop = nn.Dropout(p_dropout)
|
195 |
+
|
196 |
+
if window_size is not None:
|
197 |
+
n_heads_rel = 1 if heads_share else n_heads
|
198 |
+
rel_stddev = self.k_channels**-0.5
|
199 |
+
self.emb_rel_k = nn.Parameter(
|
200 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
201 |
+
* rel_stddev
|
202 |
+
)
|
203 |
+
self.emb_rel_v = nn.Parameter(
|
204 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
205 |
+
* rel_stddev
|
206 |
+
)
|
207 |
+
|
208 |
+
nn.init.xavier_uniform_(self.conv_q.weight)
|
209 |
+
nn.init.xavier_uniform_(self.conv_k.weight)
|
210 |
+
nn.init.xavier_uniform_(self.conv_v.weight)
|
211 |
+
if proximal_init:
|
212 |
+
with torch.no_grad():
|
213 |
+
self.conv_k.weight.copy_(self.conv_q.weight)
|
214 |
+
self.conv_k.bias.copy_(self.conv_q.bias)
|
215 |
+
|
216 |
+
def forward(self, x, c, attn_mask=None):
|
217 |
+
q = self.conv_q(x)
|
218 |
+
k = self.conv_k(c)
|
219 |
+
v = self.conv_v(c)
|
220 |
+
|
221 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
222 |
+
|
223 |
+
x = self.conv_o(x)
|
224 |
+
return x
|
225 |
+
|
226 |
+
def attention(self, query, key, value, mask=None):
|
227 |
+
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
228 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
229 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
230 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
231 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
232 |
+
|
233 |
+
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
234 |
+
if self.window_size is not None:
|
235 |
+
assert (
|
236 |
+
t_s == t_t
|
237 |
+
), "Relative attention is only available for self-attention."
|
238 |
+
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
239 |
+
rel_logits = self._matmul_with_relative_keys(
|
240 |
+
query / math.sqrt(self.k_channels), key_relative_embeddings
|
241 |
+
)
|
242 |
+
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
243 |
+
scores = scores + scores_local
|
244 |
+
if self.proximal_bias:
|
245 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
246 |
+
scores = scores + self._attention_bias_proximal(t_s).to(
|
247 |
+
device=scores.device, dtype=scores.dtype
|
248 |
+
)
|
249 |
+
if mask is not None:
|
250 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
251 |
+
if self.block_length is not None:
|
252 |
+
assert (
|
253 |
+
t_s == t_t
|
254 |
+
), "Local attention is only available for self-attention."
|
255 |
+
block_mask = (
|
256 |
+
torch.ones_like(scores)
|
257 |
+
.triu(-self.block_length)
|
258 |
+
.tril(self.block_length)
|
259 |
+
)
|
260 |
+
scores = scores.masked_fill(block_mask == 0, -1e4)
|
261 |
+
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
262 |
+
p_attn = self.drop(p_attn)
|
263 |
+
output = torch.matmul(p_attn, value)
|
264 |
+
if self.window_size is not None:
|
265 |
+
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
266 |
+
value_relative_embeddings = self._get_relative_embeddings(
|
267 |
+
self.emb_rel_v, t_s
|
268 |
+
)
|
269 |
+
output = output + self._matmul_with_relative_values(
|
270 |
+
relative_weights, value_relative_embeddings
|
271 |
+
)
|
272 |
+
output = (
|
273 |
+
output.transpose(2, 3).contiguous().view(b, d, t_t)
|
274 |
+
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
275 |
+
return output, p_attn
|
276 |
+
|
277 |
+
def _matmul_with_relative_values(self, x, y):
|
278 |
+
"""
|
279 |
+
x: [b, h, l, m]
|
280 |
+
y: [h or 1, m, d]
|
281 |
+
ret: [b, h, l, d]
|
282 |
+
"""
|
283 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
284 |
+
return ret
|
285 |
+
|
286 |
+
def _matmul_with_relative_keys(self, x, y):
|
287 |
+
"""
|
288 |
+
x: [b, h, l, d]
|
289 |
+
y: [h or 1, m, d]
|
290 |
+
ret: [b, h, l, m]
|
291 |
+
"""
|
292 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
293 |
+
return ret
|
294 |
+
|
295 |
+
def _get_relative_embeddings(self, relative_embeddings, length):
|
296 |
+
max_relative_position = 2 * self.window_size + 1
|
297 |
+
# Pad first before slice to avoid using cond ops.
|
298 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
299 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
300 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
301 |
+
if pad_length > 0:
|
302 |
+
padded_relative_embeddings = F.pad(
|
303 |
+
relative_embeddings,
|
304 |
+
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
305 |
+
)
|
306 |
+
else:
|
307 |
+
padded_relative_embeddings = relative_embeddings
|
308 |
+
used_relative_embeddings = padded_relative_embeddings[
|
309 |
+
:, slice_start_position:slice_end_position
|
310 |
+
]
|
311 |
+
return used_relative_embeddings
|
312 |
+
|
313 |
+
def _relative_position_to_absolute_position(self, x):
|
314 |
+
"""
|
315 |
+
x: [b, h, l, 2*l-1]
|
316 |
+
ret: [b, h, l, l]
|
317 |
+
"""
|
318 |
+
batch, heads, length, _ = x.size()
|
319 |
+
# Concat columns of pad to shift from relative to absolute indexing.
|
320 |
+
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
|
321 |
+
|
322 |
+
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
323 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
324 |
+
x_flat = F.pad(
|
325 |
+
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
326 |
+
)
|
327 |
+
|
328 |
+
# Reshape and slice out the padded elements.
|
329 |
+
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
330 |
+
:, :, :length, length - 1 :
|
331 |
+
]
|
332 |
+
return x_final
|
333 |
+
|
334 |
+
def _absolute_position_to_relative_position(self, x):
|
335 |
+
"""
|
336 |
+
x: [b, h, l, l]
|
337 |
+
ret: [b, h, l, 2*l-1]
|
338 |
+
"""
|
339 |
+
batch, heads, length, _ = x.size()
|
340 |
+
# padd along column
|
341 |
+
x = F.pad(
|
342 |
+
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
343 |
+
)
|
344 |
+
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
345 |
+
# add 0's in the beginning that will skew the elements after reshape
|
346 |
+
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
347 |
+
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
348 |
+
return x_final
|
349 |
+
|
350 |
+
def _attention_bias_proximal(self, length):
|
351 |
+
"""Bias for self-attention to encourage attention to close positions.
|
352 |
+
Args:
|
353 |
+
length: an integer scalar.
|
354 |
+
Returns:
|
355 |
+
a Tensor with shape [1, 1, length, length]
|
356 |
+
"""
|
357 |
+
r = torch.arange(length, dtype=torch.float32)
|
358 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
359 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
360 |
+
|
361 |
+
|
362 |
+
class FFN(nn.Module):
|
363 |
+
def __init__(
|
364 |
+
self,
|
365 |
+
in_channels,
|
366 |
+
out_channels,
|
367 |
+
filter_channels,
|
368 |
+
kernel_size,
|
369 |
+
p_dropout=0.0,
|
370 |
+
activation=None,
|
371 |
+
causal=False,
|
372 |
+
):
|
373 |
+
super().__init__()
|
374 |
+
self.in_channels = in_channels
|
375 |
+
self.out_channels = out_channels
|
376 |
+
self.filter_channels = filter_channels
|
377 |
+
self.kernel_size = kernel_size
|
378 |
+
self.p_dropout = p_dropout
|
379 |
+
self.activation = activation
|
380 |
+
self.causal = causal
|
381 |
+
|
382 |
+
if causal:
|
383 |
+
self.padding = self._causal_padding
|
384 |
+
else:
|
385 |
+
self.padding = self._same_padding
|
386 |
+
|
387 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
388 |
+
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
389 |
+
self.drop = nn.Dropout(p_dropout)
|
390 |
+
|
391 |
+
def forward(self, x, x_mask):
|
392 |
+
x = self.conv_1(self.padding(x * x_mask))
|
393 |
+
if self.activation == "gelu":
|
394 |
+
x = x * torch.sigmoid(1.702 * x)
|
395 |
+
else:
|
396 |
+
x = torch.relu(x)
|
397 |
+
x = self.drop(x)
|
398 |
+
x = self.conv_2(self.padding(x * x_mask))
|
399 |
+
return x * x_mask
|
400 |
+
|
401 |
+
def _causal_padding(self, x):
|
402 |
+
if self.kernel_size == 1:
|
403 |
+
return x
|
404 |
+
pad_l = self.kernel_size - 1
|
405 |
+
pad_r = 0
|
406 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
407 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
408 |
+
return x
|
409 |
+
|
410 |
+
def _same_padding(self, x):
|
411 |
+
if self.kernel_size == 1:
|
412 |
+
return x
|
413 |
+
pad_l = (self.kernel_size - 1) // 2
|
414 |
+
pad_r = self.kernel_size // 2
|
415 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
416 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
417 |
+
return x
|
lib/infer_pack/commons.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
|
8 |
+
def init_weights(m, mean=0.0, std=0.01):
|
9 |
+
classname = m.__class__.__name__
|
10 |
+
if classname.find("Conv") != -1:
|
11 |
+
m.weight.data.normal_(mean, std)
|
12 |
+
|
13 |
+
|
14 |
+
def get_padding(kernel_size, dilation=1):
|
15 |
+
return int((kernel_size * dilation - dilation) / 2)
|
16 |
+
|
17 |
+
|
18 |
+
def convert_pad_shape(pad_shape):
|
19 |
+
l = pad_shape[::-1]
|
20 |
+
pad_shape = [item for sublist in l for item in sublist]
|
21 |
+
return pad_shape
|
22 |
+
|
23 |
+
|
24 |
+
def kl_divergence(m_p, logs_p, m_q, logs_q):
|
25 |
+
"""KL(P||Q)"""
|
26 |
+
kl = (logs_q - logs_p) - 0.5
|
27 |
+
kl += (
|
28 |
+
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
|
29 |
+
)
|
30 |
+
return kl
|
31 |
+
|
32 |
+
|
33 |
+
def rand_gumbel(shape):
|
34 |
+
"""Sample from the Gumbel distribution, protect from overflows."""
|
35 |
+
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
|
36 |
+
return -torch.log(-torch.log(uniform_samples))
|
37 |
+
|
38 |
+
|
39 |
+
def rand_gumbel_like(x):
|
40 |
+
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
|
41 |
+
return g
|
42 |
+
|
43 |
+
|
44 |
+
def slice_segments(x, ids_str, segment_size=4):
|
45 |
+
ret = torch.zeros_like(x[:, :, :segment_size])
|
46 |
+
for i in range(x.size(0)):
|
47 |
+
idx_str = ids_str[i]
|
48 |
+
idx_end = idx_str + segment_size
|
49 |
+
ret[i] = x[i, :, idx_str:idx_end]
|
50 |
+
return ret
|
51 |
+
|
52 |
+
|
53 |
+
def slice_segments2(x, ids_str, segment_size=4):
|
54 |
+
ret = torch.zeros_like(x[:, :segment_size])
|
55 |
+
for i in range(x.size(0)):
|
56 |
+
idx_str = ids_str[i]
|
57 |
+
idx_end = idx_str + segment_size
|
58 |
+
ret[i] = x[i, idx_str:idx_end]
|
59 |
+
return ret
|
60 |
+
|
61 |
+
|
62 |
+
def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
63 |
+
b, d, t = x.size()
|
64 |
+
if x_lengths is None:
|
65 |
+
x_lengths = t
|
66 |
+
ids_str_max = x_lengths - segment_size + 1
|
67 |
+
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
68 |
+
ret = slice_segments(x, ids_str, segment_size)
|
69 |
+
return ret, ids_str
|
70 |
+
|
71 |
+
|
72 |
+
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
|
73 |
+
position = torch.arange(length, dtype=torch.float)
|
74 |
+
num_timescales = channels // 2
|
75 |
+
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
|
76 |
+
num_timescales - 1
|
77 |
+
)
|
78 |
+
inv_timescales = min_timescale * torch.exp(
|
79 |
+
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
|
80 |
+
)
|
81 |
+
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
|
82 |
+
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
|
83 |
+
signal = F.pad(signal, [0, 0, 0, channels % 2])
|
84 |
+
signal = signal.view(1, channels, length)
|
85 |
+
return signal
|
86 |
+
|
87 |
+
|
88 |
+
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
|
89 |
+
b, channels, length = x.size()
|
90 |
+
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
91 |
+
return x + signal.to(dtype=x.dtype, device=x.device)
|
92 |
+
|
93 |
+
|
94 |
+
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
|
95 |
+
b, channels, length = x.size()
|
96 |
+
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
97 |
+
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
|
98 |
+
|
99 |
+
|
100 |
+
def subsequent_mask(length):
|
101 |
+
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
|
102 |
+
return mask
|
103 |
+
|
104 |
+
|
105 |
+
@torch.jit.script
|
106 |
+
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
107 |
+
n_channels_int = n_channels[0]
|
108 |
+
in_act = input_a + input_b
|
109 |
+
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
110 |
+
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
111 |
+
acts = t_act * s_act
|
112 |
+
return acts
|
113 |
+
|
114 |
+
|
115 |
+
def convert_pad_shape(pad_shape):
|
116 |
+
l = pad_shape[::-1]
|
117 |
+
pad_shape = [item for sublist in l for item in sublist]
|
118 |
+
return pad_shape
|
119 |
+
|
120 |
+
|
121 |
+
def shift_1d(x):
|
122 |
+
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
123 |
+
return x
|
124 |
+
|
125 |
+
|
126 |
+
def sequence_mask(length, max_length=None):
|
127 |
+
if max_length is None:
|
128 |
+
max_length = length.max()
|
129 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
130 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
131 |
+
|
132 |
+
|
133 |
+
def generate_path(duration, mask):
|
134 |
+
"""
|
135 |
+
duration: [b, 1, t_x]
|
136 |
+
mask: [b, 1, t_y, t_x]
|
137 |
+
"""
|
138 |
+
device = duration.device
|
139 |
+
|
140 |
+
b, _, t_y, t_x = mask.shape
|
141 |
+
cum_duration = torch.cumsum(duration, -1)
|
142 |
+
|
143 |
+
cum_duration_flat = cum_duration.view(b * t_x)
|
144 |
+
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
145 |
+
path = path.view(b, t_x, t_y)
|
146 |
+
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
147 |
+
path = path.unsqueeze(1).transpose(2, 3) * mask
|
148 |
+
return path
|
149 |
+
|
150 |
+
|
151 |
+
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
152 |
+
if isinstance(parameters, torch.Tensor):
|
153 |
+
parameters = [parameters]
|
154 |
+
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
155 |
+
norm_type = float(norm_type)
|
156 |
+
if clip_value is not None:
|
157 |
+
clip_value = float(clip_value)
|
158 |
+
|
159 |
+
total_norm = 0
|
160 |
+
for p in parameters:
|
161 |
+
param_norm = p.grad.data.norm(norm_type)
|
162 |
+
total_norm += param_norm.item() ** norm_type
|
163 |
+
if clip_value is not None:
|
164 |
+
p.grad.data.clamp_(min=-clip_value, max=clip_value)
|
165 |
+
total_norm = total_norm ** (1.0 / norm_type)
|
166 |
+
return total_norm
|
lib/infer_pack/models.py
ADDED
@@ -0,0 +1,1142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math, pdb, os
|
2 |
+
from time import time as ttime
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from lib.infer_pack import modules
|
7 |
+
from lib.infer_pack import attentions
|
8 |
+
from lib.infer_pack import commons
|
9 |
+
from lib.infer_pack.commons import init_weights, get_padding
|
10 |
+
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
11 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
12 |
+
from lib.infer_pack.commons import init_weights
|
13 |
+
import numpy as np
|
14 |
+
from lib.infer_pack import commons
|
15 |
+
|
16 |
+
|
17 |
+
class TextEncoder256(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
out_channels,
|
21 |
+
hidden_channels,
|
22 |
+
filter_channels,
|
23 |
+
n_heads,
|
24 |
+
n_layers,
|
25 |
+
kernel_size,
|
26 |
+
p_dropout,
|
27 |
+
f0=True,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.out_channels = out_channels
|
31 |
+
self.hidden_channels = hidden_channels
|
32 |
+
self.filter_channels = filter_channels
|
33 |
+
self.n_heads = n_heads
|
34 |
+
self.n_layers = n_layers
|
35 |
+
self.kernel_size = kernel_size
|
36 |
+
self.p_dropout = p_dropout
|
37 |
+
self.emb_phone = nn.Linear(256, hidden_channels)
|
38 |
+
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
|
39 |
+
if f0 == True:
|
40 |
+
self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
|
41 |
+
self.encoder = attentions.Encoder(
|
42 |
+
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
43 |
+
)
|
44 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
45 |
+
|
46 |
+
def forward(self, phone, pitch, lengths):
|
47 |
+
if pitch == None:
|
48 |
+
x = self.emb_phone(phone)
|
49 |
+
else:
|
50 |
+
x = self.emb_phone(phone) + self.emb_pitch(pitch)
|
51 |
+
x = x * math.sqrt(self.hidden_channels) # [b, t, h]
|
52 |
+
x = self.lrelu(x)
|
53 |
+
x = torch.transpose(x, 1, -1) # [b, h, t]
|
54 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
|
55 |
+
x.dtype
|
56 |
+
)
|
57 |
+
x = self.encoder(x * x_mask, x_mask)
|
58 |
+
stats = self.proj(x) * x_mask
|
59 |
+
|
60 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
61 |
+
return m, logs, x_mask
|
62 |
+
|
63 |
+
|
64 |
+
class TextEncoder768(nn.Module):
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
out_channels,
|
68 |
+
hidden_channels,
|
69 |
+
filter_channels,
|
70 |
+
n_heads,
|
71 |
+
n_layers,
|
72 |
+
kernel_size,
|
73 |
+
p_dropout,
|
74 |
+
f0=True,
|
75 |
+
):
|
76 |
+
super().__init__()
|
77 |
+
self.out_channels = out_channels
|
78 |
+
self.hidden_channels = hidden_channels
|
79 |
+
self.filter_channels = filter_channels
|
80 |
+
self.n_heads = n_heads
|
81 |
+
self.n_layers = n_layers
|
82 |
+
self.kernel_size = kernel_size
|
83 |
+
self.p_dropout = p_dropout
|
84 |
+
self.emb_phone = nn.Linear(768, hidden_channels)
|
85 |
+
self.lrelu = nn.LeakyReLU(0.1, inplace=True)
|
86 |
+
if f0 == True:
|
87 |
+
self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
|
88 |
+
self.encoder = attentions.Encoder(
|
89 |
+
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
90 |
+
)
|
91 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
92 |
+
|
93 |
+
def forward(self, phone, pitch, lengths):
|
94 |
+
if pitch == None:
|
95 |
+
x = self.emb_phone(phone)
|
96 |
+
else:
|
97 |
+
x = self.emb_phone(phone) + self.emb_pitch(pitch)
|
98 |
+
x = x * math.sqrt(self.hidden_channels) # [b, t, h]
|
99 |
+
x = self.lrelu(x)
|
100 |
+
x = torch.transpose(x, 1, -1) # [b, h, t]
|
101 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
|
102 |
+
x.dtype
|
103 |
+
)
|
104 |
+
x = self.encoder(x * x_mask, x_mask)
|
105 |
+
stats = self.proj(x) * x_mask
|
106 |
+
|
107 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
108 |
+
return m, logs, x_mask
|
109 |
+
|
110 |
+
|
111 |
+
class ResidualCouplingBlock(nn.Module):
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
channels,
|
115 |
+
hidden_channels,
|
116 |
+
kernel_size,
|
117 |
+
dilation_rate,
|
118 |
+
n_layers,
|
119 |
+
n_flows=4,
|
120 |
+
gin_channels=0,
|
121 |
+
):
|
122 |
+
super().__init__()
|
123 |
+
self.channels = channels
|
124 |
+
self.hidden_channels = hidden_channels
|
125 |
+
self.kernel_size = kernel_size
|
126 |
+
self.dilation_rate = dilation_rate
|
127 |
+
self.n_layers = n_layers
|
128 |
+
self.n_flows = n_flows
|
129 |
+
self.gin_channels = gin_channels
|
130 |
+
|
131 |
+
self.flows = nn.ModuleList()
|
132 |
+
for i in range(n_flows):
|
133 |
+
self.flows.append(
|
134 |
+
modules.ResidualCouplingLayer(
|
135 |
+
channels,
|
136 |
+
hidden_channels,
|
137 |
+
kernel_size,
|
138 |
+
dilation_rate,
|
139 |
+
n_layers,
|
140 |
+
gin_channels=gin_channels,
|
141 |
+
mean_only=True,
|
142 |
+
)
|
143 |
+
)
|
144 |
+
self.flows.append(modules.Flip())
|
145 |
+
|
146 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
147 |
+
if not reverse:
|
148 |
+
for flow in self.flows:
|
149 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
150 |
+
else:
|
151 |
+
for flow in reversed(self.flows):
|
152 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
153 |
+
return x
|
154 |
+
|
155 |
+
def remove_weight_norm(self):
|
156 |
+
for i in range(self.n_flows):
|
157 |
+
self.flows[i * 2].remove_weight_norm()
|
158 |
+
|
159 |
+
|
160 |
+
class PosteriorEncoder(nn.Module):
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
in_channels,
|
164 |
+
out_channels,
|
165 |
+
hidden_channels,
|
166 |
+
kernel_size,
|
167 |
+
dilation_rate,
|
168 |
+
n_layers,
|
169 |
+
gin_channels=0,
|
170 |
+
):
|
171 |
+
super().__init__()
|
172 |
+
self.in_channels = in_channels
|
173 |
+
self.out_channels = out_channels
|
174 |
+
self.hidden_channels = hidden_channels
|
175 |
+
self.kernel_size = kernel_size
|
176 |
+
self.dilation_rate = dilation_rate
|
177 |
+
self.n_layers = n_layers
|
178 |
+
self.gin_channels = gin_channels
|
179 |
+
|
180 |
+
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
181 |
+
self.enc = modules.WN(
|
182 |
+
hidden_channels,
|
183 |
+
kernel_size,
|
184 |
+
dilation_rate,
|
185 |
+
n_layers,
|
186 |
+
gin_channels=gin_channels,
|
187 |
+
)
|
188 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
189 |
+
|
190 |
+
def forward(self, x, x_lengths, g=None):
|
191 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
192 |
+
x.dtype
|
193 |
+
)
|
194 |
+
x = self.pre(x) * x_mask
|
195 |
+
x = self.enc(x, x_mask, g=g)
|
196 |
+
stats = self.proj(x) * x_mask
|
197 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
198 |
+
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
199 |
+
return z, m, logs, x_mask
|
200 |
+
|
201 |
+
def remove_weight_norm(self):
|
202 |
+
self.enc.remove_weight_norm()
|
203 |
+
|
204 |
+
|
205 |
+
class Generator(torch.nn.Module):
|
206 |
+
def __init__(
|
207 |
+
self,
|
208 |
+
initial_channel,
|
209 |
+
resblock,
|
210 |
+
resblock_kernel_sizes,
|
211 |
+
resblock_dilation_sizes,
|
212 |
+
upsample_rates,
|
213 |
+
upsample_initial_channel,
|
214 |
+
upsample_kernel_sizes,
|
215 |
+
gin_channels=0,
|
216 |
+
):
|
217 |
+
super(Generator, self).__init__()
|
218 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
219 |
+
self.num_upsamples = len(upsample_rates)
|
220 |
+
self.conv_pre = Conv1d(
|
221 |
+
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
222 |
+
)
|
223 |
+
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
224 |
+
|
225 |
+
self.ups = nn.ModuleList()
|
226 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
227 |
+
self.ups.append(
|
228 |
+
weight_norm(
|
229 |
+
ConvTranspose1d(
|
230 |
+
upsample_initial_channel // (2**i),
|
231 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
232 |
+
k,
|
233 |
+
u,
|
234 |
+
padding=(k - u) // 2,
|
235 |
+
)
|
236 |
+
)
|
237 |
+
)
|
238 |
+
|
239 |
+
self.resblocks = nn.ModuleList()
|
240 |
+
for i in range(len(self.ups)):
|
241 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
242 |
+
for j, (k, d) in enumerate(
|
243 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
244 |
+
):
|
245 |
+
self.resblocks.append(resblock(ch, k, d))
|
246 |
+
|
247 |
+
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
248 |
+
self.ups.apply(init_weights)
|
249 |
+
|
250 |
+
if gin_channels != 0:
|
251 |
+
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
252 |
+
|
253 |
+
def forward(self, x, g=None):
|
254 |
+
x = self.conv_pre(x)
|
255 |
+
if g is not None:
|
256 |
+
x = x + self.cond(g)
|
257 |
+
|
258 |
+
for i in range(self.num_upsamples):
|
259 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
260 |
+
x = self.ups[i](x)
|
261 |
+
xs = None
|
262 |
+
for j in range(self.num_kernels):
|
263 |
+
if xs is None:
|
264 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
265 |
+
else:
|
266 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
267 |
+
x = xs / self.num_kernels
|
268 |
+
x = F.leaky_relu(x)
|
269 |
+
x = self.conv_post(x)
|
270 |
+
x = torch.tanh(x)
|
271 |
+
|
272 |
+
return x
|
273 |
+
|
274 |
+
def remove_weight_norm(self):
|
275 |
+
for l in self.ups:
|
276 |
+
remove_weight_norm(l)
|
277 |
+
for l in self.resblocks:
|
278 |
+
l.remove_weight_norm()
|
279 |
+
|
280 |
+
|
281 |
+
class SineGen(torch.nn.Module):
|
282 |
+
"""Definition of sine generator
|
283 |
+
SineGen(samp_rate, harmonic_num = 0,
|
284 |
+
sine_amp = 0.1, noise_std = 0.003,
|
285 |
+
voiced_threshold = 0,
|
286 |
+
flag_for_pulse=False)
|
287 |
+
samp_rate: sampling rate in Hz
|
288 |
+
harmonic_num: number of harmonic overtones (default 0)
|
289 |
+
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
290 |
+
noise_std: std of Gaussian noise (default 0.003)
|
291 |
+
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
292 |
+
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
293 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
294 |
+
segment is always sin(np.pi) or cos(0)
|
295 |
+
"""
|
296 |
+
|
297 |
+
def __init__(
|
298 |
+
self,
|
299 |
+
samp_rate,
|
300 |
+
harmonic_num=0,
|
301 |
+
sine_amp=0.1,
|
302 |
+
noise_std=0.003,
|
303 |
+
voiced_threshold=0,
|
304 |
+
flag_for_pulse=False,
|
305 |
+
):
|
306 |
+
super(SineGen, self).__init__()
|
307 |
+
self.sine_amp = sine_amp
|
308 |
+
self.noise_std = noise_std
|
309 |
+
self.harmonic_num = harmonic_num
|
310 |
+
self.dim = self.harmonic_num + 1
|
311 |
+
self.sampling_rate = samp_rate
|
312 |
+
self.voiced_threshold = voiced_threshold
|
313 |
+
|
314 |
+
def _f02uv(self, f0):
|
315 |
+
# generate uv signal
|
316 |
+
uv = torch.ones_like(f0)
|
317 |
+
uv = uv * (f0 > self.voiced_threshold)
|
318 |
+
return uv
|
319 |
+
|
320 |
+
def forward(self, f0, upp):
|
321 |
+
"""sine_tensor, uv = forward(f0)
|
322 |
+
input F0: tensor(batchsize=1, length, dim=1)
|
323 |
+
f0 for unvoiced steps should be 0
|
324 |
+
output sine_tensor: tensor(batchsize=1, length, dim)
|
325 |
+
output uv: tensor(batchsize=1, length, 1)
|
326 |
+
"""
|
327 |
+
with torch.no_grad():
|
328 |
+
f0 = f0[:, None].transpose(1, 2)
|
329 |
+
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
|
330 |
+
# fundamental component
|
331 |
+
f0_buf[:, :, 0] = f0[:, :, 0]
|
332 |
+
for idx in np.arange(self.harmonic_num):
|
333 |
+
f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
|
334 |
+
idx + 2
|
335 |
+
) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
|
336 |
+
rad_values = (f0_buf / self.sampling_rate) % 1 ###%1 means that the product of n_har cannot be post-processed and optimized
|
337 |
+
rand_ini = torch.rand(
|
338 |
+
f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
|
339 |
+
)
|
340 |
+
rand_ini[:, 0] = 0
|
341 |
+
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
342 |
+
tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1 means that the following cumsum can no longer be optimized
|
343 |
+
tmp_over_one *= upp
|
344 |
+
tmp_over_one = F.interpolate(
|
345 |
+
tmp_over_one.transpose(2, 1),
|
346 |
+
scale_factor=upp,
|
347 |
+
mode="linear",
|
348 |
+
align_corners=True,
|
349 |
+
).transpose(2, 1)
|
350 |
+
rad_values = F.interpolate(
|
351 |
+
rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
|
352 |
+
).transpose(
|
353 |
+
2, 1
|
354 |
+
) #######
|
355 |
+
tmp_over_one %= 1
|
356 |
+
tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
|
357 |
+
cumsum_shift = torch.zeros_like(rad_values)
|
358 |
+
cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
359 |
+
sine_waves = torch.sin(
|
360 |
+
torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
|
361 |
+
)
|
362 |
+
sine_waves = sine_waves * self.sine_amp
|
363 |
+
uv = self._f02uv(f0)
|
364 |
+
uv = F.interpolate(
|
365 |
+
uv.transpose(2, 1), scale_factor=upp, mode="nearest"
|
366 |
+
).transpose(2, 1)
|
367 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
368 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
369 |
+
sine_waves = sine_waves * uv + noise
|
370 |
+
return sine_waves, uv, noise
|
371 |
+
|
372 |
+
|
373 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
374 |
+
"""SourceModule for hn-nsf
|
375 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
376 |
+
add_noise_std=0.003, voiced_threshod=0)
|
377 |
+
sampling_rate: sampling_rate in Hz
|
378 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
379 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
380 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
381 |
+
note that amplitude of noise in unvoiced is decided
|
382 |
+
by sine_amp
|
383 |
+
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
384 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
385 |
+
F0_sampled (batchsize, length, 1)
|
386 |
+
Sine_source (batchsize, length, 1)
|
387 |
+
noise_source (batchsize, length 1)
|
388 |
+
uv (batchsize, length, 1)
|
389 |
+
"""
|
390 |
+
|
391 |
+
def __init__(
|
392 |
+
self,
|
393 |
+
sampling_rate,
|
394 |
+
harmonic_num=0,
|
395 |
+
sine_amp=0.1,
|
396 |
+
add_noise_std=0.003,
|
397 |
+
voiced_threshod=0,
|
398 |
+
is_half=True,
|
399 |
+
):
|
400 |
+
super(SourceModuleHnNSF, self).__init__()
|
401 |
+
|
402 |
+
self.sine_amp = sine_amp
|
403 |
+
self.noise_std = add_noise_std
|
404 |
+
self.is_half = is_half
|
405 |
+
# to produce sine waveforms
|
406 |
+
self.l_sin_gen = SineGen(
|
407 |
+
sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
|
408 |
+
)
|
409 |
+
|
410 |
+
# to merge source harmonics into a single excitation
|
411 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
412 |
+
self.l_tanh = torch.nn.Tanh()
|
413 |
+
|
414 |
+
def forward(self, x, upp=None):
|
415 |
+
sine_wavs, uv, _ = self.l_sin_gen(x, upp)
|
416 |
+
if self.is_half:
|
417 |
+
sine_wavs = sine_wavs.half()
|
418 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
419 |
+
return sine_merge, None, None # noise, uv
|
420 |
+
|
421 |
+
|
422 |
+
class GeneratorNSF(torch.nn.Module):
|
423 |
+
def __init__(
|
424 |
+
self,
|
425 |
+
initial_channel,
|
426 |
+
resblock,
|
427 |
+
resblock_kernel_sizes,
|
428 |
+
resblock_dilation_sizes,
|
429 |
+
upsample_rates,
|
430 |
+
upsample_initial_channel,
|
431 |
+
upsample_kernel_sizes,
|
432 |
+
gin_channels,
|
433 |
+
sr,
|
434 |
+
is_half=False,
|
435 |
+
):
|
436 |
+
super(GeneratorNSF, self).__init__()
|
437 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
438 |
+
self.num_upsamples = len(upsample_rates)
|
439 |
+
|
440 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
|
441 |
+
self.m_source = SourceModuleHnNSF(
|
442 |
+
sampling_rate=sr, harmonic_num=0, is_half=is_half
|
443 |
+
)
|
444 |
+
self.noise_convs = nn.ModuleList()
|
445 |
+
self.conv_pre = Conv1d(
|
446 |
+
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
447 |
+
)
|
448 |
+
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
449 |
+
|
450 |
+
self.ups = nn.ModuleList()
|
451 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
452 |
+
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
453 |
+
self.ups.append(
|
454 |
+
weight_norm(
|
455 |
+
ConvTranspose1d(
|
456 |
+
upsample_initial_channel // (2**i),
|
457 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
458 |
+
k,
|
459 |
+
u,
|
460 |
+
padding=(k - u) // 2,
|
461 |
+
)
|
462 |
+
)
|
463 |
+
)
|
464 |
+
if i + 1 < len(upsample_rates):
|
465 |
+
stride_f0 = np.prod(upsample_rates[i + 1 :])
|
466 |
+
self.noise_convs.append(
|
467 |
+
Conv1d(
|
468 |
+
1,
|
469 |
+
c_cur,
|
470 |
+
kernel_size=stride_f0 * 2,
|
471 |
+
stride=stride_f0,
|
472 |
+
padding=stride_f0 // 2,
|
473 |
+
)
|
474 |
+
)
|
475 |
+
else:
|
476 |
+
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
|
477 |
+
|
478 |
+
self.resblocks = nn.ModuleList()
|
479 |
+
for i in range(len(self.ups)):
|
480 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
481 |
+
for j, (k, d) in enumerate(
|
482 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
483 |
+
):
|
484 |
+
self.resblocks.append(resblock(ch, k, d))
|
485 |
+
|
486 |
+
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
487 |
+
self.ups.apply(init_weights)
|
488 |
+
|
489 |
+
if gin_channels != 0:
|
490 |
+
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
491 |
+
|
492 |
+
self.upp = np.prod(upsample_rates)
|
493 |
+
|
494 |
+
def forward(self, x, f0, g=None):
|
495 |
+
har_source, noi_source, uv = self.m_source(f0, self.upp)
|
496 |
+
har_source = har_source.transpose(1, 2)
|
497 |
+
x = self.conv_pre(x)
|
498 |
+
if g is not None:
|
499 |
+
x = x + self.cond(g)
|
500 |
+
|
501 |
+
for i in range(self.num_upsamples):
|
502 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
503 |
+
x = self.ups[i](x)
|
504 |
+
x_source = self.noise_convs[i](har_source)
|
505 |
+
x = x + x_source
|
506 |
+
xs = None
|
507 |
+
for j in range(self.num_kernels):
|
508 |
+
if xs is None:
|
509 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
510 |
+
else:
|
511 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
512 |
+
x = xs / self.num_kernels
|
513 |
+
x = F.leaky_relu(x)
|
514 |
+
x = self.conv_post(x)
|
515 |
+
x = torch.tanh(x)
|
516 |
+
return x
|
517 |
+
|
518 |
+
def remove_weight_norm(self):
|
519 |
+
for l in self.ups:
|
520 |
+
remove_weight_norm(l)
|
521 |
+
for l in self.resblocks:
|
522 |
+
l.remove_weight_norm()
|
523 |
+
|
524 |
+
|
525 |
+
sr2sr = {
|
526 |
+
"32k": 32000,
|
527 |
+
"40k": 40000,
|
528 |
+
"48k": 48000,
|
529 |
+
}
|
530 |
+
|
531 |
+
|
532 |
+
class SynthesizerTrnMs256NSFsid(nn.Module):
|
533 |
+
def __init__(
|
534 |
+
self,
|
535 |
+
spec_channels,
|
536 |
+
segment_size,
|
537 |
+
inter_channels,
|
538 |
+
hidden_channels,
|
539 |
+
filter_channels,
|
540 |
+
n_heads,
|
541 |
+
n_layers,
|
542 |
+
kernel_size,
|
543 |
+
p_dropout,
|
544 |
+
resblock,
|
545 |
+
resblock_kernel_sizes,
|
546 |
+
resblock_dilation_sizes,
|
547 |
+
upsample_rates,
|
548 |
+
upsample_initial_channel,
|
549 |
+
upsample_kernel_sizes,
|
550 |
+
spk_embed_dim,
|
551 |
+
gin_channels,
|
552 |
+
sr,
|
553 |
+
**kwargs
|
554 |
+
):
|
555 |
+
super().__init__()
|
556 |
+
if type(sr) == type("strr"):
|
557 |
+
sr = sr2sr[sr]
|
558 |
+
self.spec_channels = spec_channels
|
559 |
+
self.inter_channels = inter_channels
|
560 |
+
self.hidden_channels = hidden_channels
|
561 |
+
self.filter_channels = filter_channels
|
562 |
+
self.n_heads = n_heads
|
563 |
+
self.n_layers = n_layers
|
564 |
+
self.kernel_size = kernel_size
|
565 |
+
self.p_dropout = p_dropout
|
566 |
+
self.resblock = resblock
|
567 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
568 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
569 |
+
self.upsample_rates = upsample_rates
|
570 |
+
self.upsample_initial_channel = upsample_initial_channel
|
571 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
572 |
+
self.segment_size = segment_size
|
573 |
+
self.gin_channels = gin_channels
|
574 |
+
# self.hop_length = hop_length#
|
575 |
+
self.spk_embed_dim = spk_embed_dim
|
576 |
+
self.enc_p = TextEncoder256(
|
577 |
+
inter_channels,
|
578 |
+
hidden_channels,
|
579 |
+
filter_channels,
|
580 |
+
n_heads,
|
581 |
+
n_layers,
|
582 |
+
kernel_size,
|
583 |
+
p_dropout,
|
584 |
+
)
|
585 |
+
self.dec = GeneratorNSF(
|
586 |
+
inter_channels,
|
587 |
+
resblock,
|
588 |
+
resblock_kernel_sizes,
|
589 |
+
resblock_dilation_sizes,
|
590 |
+
upsample_rates,
|
591 |
+
upsample_initial_channel,
|
592 |
+
upsample_kernel_sizes,
|
593 |
+
gin_channels=gin_channels,
|
594 |
+
sr=sr,
|
595 |
+
is_half=kwargs["is_half"],
|
596 |
+
)
|
597 |
+
self.enc_q = PosteriorEncoder(
|
598 |
+
spec_channels,
|
599 |
+
inter_channels,
|
600 |
+
hidden_channels,
|
601 |
+
5,
|
602 |
+
1,
|
603 |
+
16,
|
604 |
+
gin_channels=gin_channels,
|
605 |
+
)
|
606 |
+
self.flow = ResidualCouplingBlock(
|
607 |
+
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
|
608 |
+
)
|
609 |
+
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
|
610 |
+
print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
|
611 |
+
|
612 |
+
def remove_weight_norm(self):
|
613 |
+
self.dec.remove_weight_norm()
|
614 |
+
self.flow.remove_weight_norm()
|
615 |
+
self.enc_q.remove_weight_norm()
|
616 |
+
|
617 |
+
def forward(
|
618 |
+
self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
|
619 |
+
): # Here ds is id, [bs,1]
|
620 |
+
# print(1,pitch.shape)#[bs,t]
|
621 |
+
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1 is t, broadcast
|
622 |
+
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
623 |
+
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
624 |
+
z_p = self.flow(z, y_mask, g=g)
|
625 |
+
z_slice, ids_slice = commons.rand_slice_segments(
|
626 |
+
z, y_lengths, self.segment_size
|
627 |
+
)
|
628 |
+
# print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
|
629 |
+
pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
|
630 |
+
# print(-2,pitchf.shape,z_slice.shape)
|
631 |
+
o = self.dec(z_slice, pitchf, g=g)
|
632 |
+
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
633 |
+
|
634 |
+
def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
|
635 |
+
g = self.emb_g(sid).unsqueeze(-1)
|
636 |
+
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
637 |
+
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
638 |
+
if rate:
|
639 |
+
head = int(z_p.shape[2] * rate)
|
640 |
+
z_p = z_p[:, :, -head:]
|
641 |
+
x_mask = x_mask[:, :, -head:]
|
642 |
+
nsff0 = nsff0[:, -head:]
|
643 |
+
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
644 |
+
o = self.dec(z * x_mask, nsff0, g=g)
|
645 |
+
return o, x_mask, (z, z_p, m_p, logs_p)
|
646 |
+
|
647 |
+
|
648 |
+
class SynthesizerTrnMs768NSFsid(nn.Module):
|
649 |
+
def __init__(
|
650 |
+
self,
|
651 |
+
spec_channels,
|
652 |
+
segment_size,
|
653 |
+
inter_channels,
|
654 |
+
hidden_channels,
|
655 |
+
filter_channels,
|
656 |
+
n_heads,
|
657 |
+
n_layers,
|
658 |
+
kernel_size,
|
659 |
+
p_dropout,
|
660 |
+
resblock,
|
661 |
+
resblock_kernel_sizes,
|
662 |
+
resblock_dilation_sizes,
|
663 |
+
upsample_rates,
|
664 |
+
upsample_initial_channel,
|
665 |
+
upsample_kernel_sizes,
|
666 |
+
spk_embed_dim,
|
667 |
+
gin_channels,
|
668 |
+
sr,
|
669 |
+
**kwargs
|
670 |
+
):
|
671 |
+
super().__init__()
|
672 |
+
if type(sr) == type("strr"):
|
673 |
+
sr = sr2sr[sr]
|
674 |
+
self.spec_channels = spec_channels
|
675 |
+
self.inter_channels = inter_channels
|
676 |
+
self.hidden_channels = hidden_channels
|
677 |
+
self.filter_channels = filter_channels
|
678 |
+
self.n_heads = n_heads
|
679 |
+
self.n_layers = n_layers
|
680 |
+
self.kernel_size = kernel_size
|
681 |
+
self.p_dropout = p_dropout
|
682 |
+
self.resblock = resblock
|
683 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
684 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
685 |
+
self.upsample_rates = upsample_rates
|
686 |
+
self.upsample_initial_channel = upsample_initial_channel
|
687 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
688 |
+
self.segment_size = segment_size
|
689 |
+
self.gin_channels = gin_channels
|
690 |
+
# self.hop_length = hop_length#
|
691 |
+
self.spk_embed_dim = spk_embed_dim
|
692 |
+
self.enc_p = TextEncoder768(
|
693 |
+
inter_channels,
|
694 |
+
hidden_channels,
|
695 |
+
filter_channels,
|
696 |
+
n_heads,
|
697 |
+
n_layers,
|
698 |
+
kernel_size,
|
699 |
+
p_dropout,
|
700 |
+
)
|
701 |
+
self.dec = GeneratorNSF(
|
702 |
+
inter_channels,
|
703 |
+
resblock,
|
704 |
+
resblock_kernel_sizes,
|
705 |
+
resblock_dilation_sizes,
|
706 |
+
upsample_rates,
|
707 |
+
upsample_initial_channel,
|
708 |
+
upsample_kernel_sizes,
|
709 |
+
gin_channels=gin_channels,
|
710 |
+
sr=sr,
|
711 |
+
is_half=kwargs["is_half"],
|
712 |
+
)
|
713 |
+
self.enc_q = PosteriorEncoder(
|
714 |
+
spec_channels,
|
715 |
+
inter_channels,
|
716 |
+
hidden_channels,
|
717 |
+
5,
|
718 |
+
1,
|
719 |
+
16,
|
720 |
+
gin_channels=gin_channels,
|
721 |
+
)
|
722 |
+
self.flow = ResidualCouplingBlock(
|
723 |
+
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
|
724 |
+
)
|
725 |
+
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
|
726 |
+
print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
|
727 |
+
|
728 |
+
def remove_weight_norm(self):
|
729 |
+
self.dec.remove_weight_norm()
|
730 |
+
self.flow.remove_weight_norm()
|
731 |
+
self.enc_q.remove_weight_norm()
|
732 |
+
|
733 |
+
def forward(
|
734 |
+
self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
|
735 |
+
): # Here ds is id,[bs,1]
|
736 |
+
# print(1,pitch.shape)#[bs,t]
|
737 |
+
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1 is t, broadcast
|
738 |
+
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
739 |
+
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
740 |
+
z_p = self.flow(z, y_mask, g=g)
|
741 |
+
z_slice, ids_slice = commons.rand_slice_segments(
|
742 |
+
z, y_lengths, self.segment_size
|
743 |
+
)
|
744 |
+
# print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
|
745 |
+
pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
|
746 |
+
# print(-2,pitchf.shape,z_slice.shape)
|
747 |
+
o = self.dec(z_slice, pitchf, g=g)
|
748 |
+
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
749 |
+
|
750 |
+
def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
|
751 |
+
g = self.emb_g(sid).unsqueeze(-1)
|
752 |
+
m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
|
753 |
+
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
754 |
+
if rate:
|
755 |
+
head = int(z_p.shape[2] * rate)
|
756 |
+
z_p = z_p[:, :, -head:]
|
757 |
+
x_mask = x_mask[:, :, -head:]
|
758 |
+
nsff0 = nsff0[:, -head:]
|
759 |
+
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
760 |
+
o = self.dec(z * x_mask, nsff0, g=g)
|
761 |
+
return o, x_mask, (z, z_p, m_p, logs_p)
|
762 |
+
|
763 |
+
|
764 |
+
class SynthesizerTrnMs256NSFsid_nono(nn.Module):
|
765 |
+
def __init__(
|
766 |
+
self,
|
767 |
+
spec_channels,
|
768 |
+
segment_size,
|
769 |
+
inter_channels,
|
770 |
+
hidden_channels,
|
771 |
+
filter_channels,
|
772 |
+
n_heads,
|
773 |
+
n_layers,
|
774 |
+
kernel_size,
|
775 |
+
p_dropout,
|
776 |
+
resblock,
|
777 |
+
resblock_kernel_sizes,
|
778 |
+
resblock_dilation_sizes,
|
779 |
+
upsample_rates,
|
780 |
+
upsample_initial_channel,
|
781 |
+
upsample_kernel_sizes,
|
782 |
+
spk_embed_dim,
|
783 |
+
gin_channels,
|
784 |
+
sr=None,
|
785 |
+
**kwargs
|
786 |
+
):
|
787 |
+
super().__init__()
|
788 |
+
self.spec_channels = spec_channels
|
789 |
+
self.inter_channels = inter_channels
|
790 |
+
self.hidden_channels = hidden_channels
|
791 |
+
self.filter_channels = filter_channels
|
792 |
+
self.n_heads = n_heads
|
793 |
+
self.n_layers = n_layers
|
794 |
+
self.kernel_size = kernel_size
|
795 |
+
self.p_dropout = p_dropout
|
796 |
+
self.resblock = resblock
|
797 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
798 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
799 |
+
self.upsample_rates = upsample_rates
|
800 |
+
self.upsample_initial_channel = upsample_initial_channel
|
801 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
802 |
+
self.segment_size = segment_size
|
803 |
+
self.gin_channels = gin_channels
|
804 |
+
# self.hop_length = hop_length#
|
805 |
+
self.spk_embed_dim = spk_embed_dim
|
806 |
+
self.enc_p = TextEncoder256(
|
807 |
+
inter_channels,
|
808 |
+
hidden_channels,
|
809 |
+
filter_channels,
|
810 |
+
n_heads,
|
811 |
+
n_layers,
|
812 |
+
kernel_size,
|
813 |
+
p_dropout,
|
814 |
+
f0=False,
|
815 |
+
)
|
816 |
+
self.dec = Generator(
|
817 |
+
inter_channels,
|
818 |
+
resblock,
|
819 |
+
resblock_kernel_sizes,
|
820 |
+
resblock_dilation_sizes,
|
821 |
+
upsample_rates,
|
822 |
+
upsample_initial_channel,
|
823 |
+
upsample_kernel_sizes,
|
824 |
+
gin_channels=gin_channels,
|
825 |
+
)
|
826 |
+
self.enc_q = PosteriorEncoder(
|
827 |
+
spec_channels,
|
828 |
+
inter_channels,
|
829 |
+
hidden_channels,
|
830 |
+
5,
|
831 |
+
1,
|
832 |
+
16,
|
833 |
+
gin_channels=gin_channels,
|
834 |
+
)
|
835 |
+
self.flow = ResidualCouplingBlock(
|
836 |
+
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
|
837 |
+
)
|
838 |
+
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
|
839 |
+
print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
|
840 |
+
|
841 |
+
def remove_weight_norm(self):
|
842 |
+
self.dec.remove_weight_norm()
|
843 |
+
self.flow.remove_weight_norm()
|
844 |
+
self.enc_q.remove_weight_norm()
|
845 |
+
|
846 |
+
def forward(self, phone, phone_lengths, y, y_lengths, ds): # Here ds is id,[bs,1]
|
847 |
+
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1 is t, broadcast
|
848 |
+
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
849 |
+
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
850 |
+
z_p = self.flow(z, y_mask, g=g)
|
851 |
+
z_slice, ids_slice = commons.rand_slice_segments(
|
852 |
+
z, y_lengths, self.segment_size
|
853 |
+
)
|
854 |
+
o = self.dec(z_slice, g=g)
|
855 |
+
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
856 |
+
|
857 |
+
def infer(self, phone, phone_lengths, sid, rate=None):
|
858 |
+
g = self.emb_g(sid).unsqueeze(-1)
|
859 |
+
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
860 |
+
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
861 |
+
if rate:
|
862 |
+
head = int(z_p.shape[2] * rate)
|
863 |
+
z_p = z_p[:, :, -head:]
|
864 |
+
x_mask = x_mask[:, :, -head:]
|
865 |
+
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
866 |
+
o = self.dec(z * x_mask, g=g)
|
867 |
+
return o, x_mask, (z, z_p, m_p, logs_p)
|
868 |
+
|
869 |
+
|
870 |
+
class SynthesizerTrnMs768NSFsid_nono(nn.Module):
|
871 |
+
def __init__(
|
872 |
+
self,
|
873 |
+
spec_channels,
|
874 |
+
segment_size,
|
875 |
+
inter_channels,
|
876 |
+
hidden_channels,
|
877 |
+
filter_channels,
|
878 |
+
n_heads,
|
879 |
+
n_layers,
|
880 |
+
kernel_size,
|
881 |
+
p_dropout,
|
882 |
+
resblock,
|
883 |
+
resblock_kernel_sizes,
|
884 |
+
resblock_dilation_sizes,
|
885 |
+
upsample_rates,
|
886 |
+
upsample_initial_channel,
|
887 |
+
upsample_kernel_sizes,
|
888 |
+
spk_embed_dim,
|
889 |
+
gin_channels,
|
890 |
+
sr=None,
|
891 |
+
**kwargs
|
892 |
+
):
|
893 |
+
super().__init__()
|
894 |
+
self.spec_channels = spec_channels
|
895 |
+
self.inter_channels = inter_channels
|
896 |
+
self.hidden_channels = hidden_channels
|
897 |
+
self.filter_channels = filter_channels
|
898 |
+
self.n_heads = n_heads
|
899 |
+
self.n_layers = n_layers
|
900 |
+
self.kernel_size = kernel_size
|
901 |
+
self.p_dropout = p_dropout
|
902 |
+
self.resblock = resblock
|
903 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
904 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
905 |
+
self.upsample_rates = upsample_rates
|
906 |
+
self.upsample_initial_channel = upsample_initial_channel
|
907 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
908 |
+
self.segment_size = segment_size
|
909 |
+
self.gin_channels = gin_channels
|
910 |
+
# self.hop_length = hop_length#
|
911 |
+
self.spk_embed_dim = spk_embed_dim
|
912 |
+
self.enc_p = TextEncoder768(
|
913 |
+
inter_channels,
|
914 |
+
hidden_channels,
|
915 |
+
filter_channels,
|
916 |
+
n_heads,
|
917 |
+
n_layers,
|
918 |
+
kernel_size,
|
919 |
+
p_dropout,
|
920 |
+
f0=False,
|
921 |
+
)
|
922 |
+
self.dec = Generator(
|
923 |
+
inter_channels,
|
924 |
+
resblock,
|
925 |
+
resblock_kernel_sizes,
|
926 |
+
resblock_dilation_sizes,
|
927 |
+
upsample_rates,
|
928 |
+
upsample_initial_channel,
|
929 |
+
upsample_kernel_sizes,
|
930 |
+
gin_channels=gin_channels,
|
931 |
+
)
|
932 |
+
self.enc_q = PosteriorEncoder(
|
933 |
+
spec_channels,
|
934 |
+
inter_channels,
|
935 |
+
hidden_channels,
|
936 |
+
5,
|
937 |
+
1,
|
938 |
+
16,
|
939 |
+
gin_channels=gin_channels,
|
940 |
+
)
|
941 |
+
self.flow = ResidualCouplingBlock(
|
942 |
+
inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
|
943 |
+
)
|
944 |
+
self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
|
945 |
+
print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
|
946 |
+
|
947 |
+
def remove_weight_norm(self):
|
948 |
+
self.dec.remove_weight_norm()
|
949 |
+
self.flow.remove_weight_norm()
|
950 |
+
self.enc_q.remove_weight_norm()
|
951 |
+
|
952 |
+
def forward(self, phone, phone_lengths, y, y_lengths, ds): # Here ds is id,[bs,1]
|
953 |
+
g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1 is t, broadcast
|
954 |
+
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
955 |
+
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
956 |
+
z_p = self.flow(z, y_mask, g=g)
|
957 |
+
z_slice, ids_slice = commons.rand_slice_segments(
|
958 |
+
z, y_lengths, self.segment_size
|
959 |
+
)
|
960 |
+
o = self.dec(z_slice, g=g)
|
961 |
+
return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
962 |
+
|
963 |
+
def infer(self, phone, phone_lengths, sid, rate=None):
|
964 |
+
g = self.emb_g(sid).unsqueeze(-1)
|
965 |
+
m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
|
966 |
+
z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
|
967 |
+
if rate:
|
968 |
+
head = int(z_p.shape[2] * rate)
|
969 |
+
z_p = z_p[:, :, -head:]
|
970 |
+
x_mask = x_mask[:, :, -head:]
|
971 |
+
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
972 |
+
o = self.dec(z * x_mask, g=g)
|
973 |
+
return o, x_mask, (z, z_p, m_p, logs_p)
|
974 |
+
|
975 |
+
|
976 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
977 |
+
def __init__(self, use_spectral_norm=False):
|
978 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
979 |
+
periods = [2, 3, 5, 7, 11, 17]
|
980 |
+
# periods = [3, 5, 7, 11, 17, 23, 37]
|
981 |
+
|
982 |
+
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
983 |
+
discs = discs + [
|
984 |
+
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
985 |
+
]
|
986 |
+
self.discriminators = nn.ModuleList(discs)
|
987 |
+
|
988 |
+
def forward(self, y, y_hat):
|
989 |
+
y_d_rs = [] #
|
990 |
+
y_d_gs = []
|
991 |
+
fmap_rs = []
|
992 |
+
fmap_gs = []
|
993 |
+
for i, d in enumerate(self.discriminators):
|
994 |
+
y_d_r, fmap_r = d(y)
|
995 |
+
y_d_g, fmap_g = d(y_hat)
|
996 |
+
# for j in range(len(fmap_r)):
|
997 |
+
# print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
|
998 |
+
y_d_rs.append(y_d_r)
|
999 |
+
y_d_gs.append(y_d_g)
|
1000 |
+
fmap_rs.append(fmap_r)
|
1001 |
+
fmap_gs.append(fmap_g)
|
1002 |
+
|
1003 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
1004 |
+
|
1005 |
+
|
1006 |
+
class MultiPeriodDiscriminatorV2(torch.nn.Module):
|
1007 |
+
def __init__(self, use_spectral_norm=False):
|
1008 |
+
super(MultiPeriodDiscriminatorV2, self).__init__()
|
1009 |
+
# periods = [2, 3, 5, 7, 11, 17]
|
1010 |
+
periods = [2, 3, 5, 7, 11, 17, 23, 37]
|
1011 |
+
|
1012 |
+
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
1013 |
+
discs = discs + [
|
1014 |
+
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
1015 |
+
]
|
1016 |
+
self.discriminators = nn.ModuleList(discs)
|
1017 |
+
|
1018 |
+
def forward(self, y, y_hat):
|
1019 |
+
y_d_rs = [] #
|
1020 |
+
y_d_gs = []
|
1021 |
+
fmap_rs = []
|
1022 |
+
fmap_gs = []
|
1023 |
+
for i, d in enumerate(self.discriminators):
|
1024 |
+
y_d_r, fmap_r = d(y)
|
1025 |
+
y_d_g, fmap_g = d(y_hat)
|
1026 |
+
# for j in range(len(fmap_r)):
|
1027 |
+
# print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
|
1028 |
+
y_d_rs.append(y_d_r)
|
1029 |
+
y_d_gs.append(y_d_g)
|
1030 |
+
fmap_rs.append(fmap_r)
|
1031 |
+
fmap_gs.append(fmap_g)
|
1032 |
+
|
1033 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
1034 |
+
|
1035 |
+
|
1036 |
+
class DiscriminatorS(torch.nn.Module):
|
1037 |
+
def __init__(self, use_spectral_norm=False):
|
1038 |
+
super(DiscriminatorS, self).__init__()
|
1039 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
1040 |
+
self.convs = nn.ModuleList(
|
1041 |
+
[
|
1042 |
+
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
1043 |
+
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
1044 |
+
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
1045 |
+
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
1046 |
+
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
1047 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
1048 |
+
]
|
1049 |
+
)
|
1050 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
1051 |
+
|
1052 |
+
def forward(self, x):
|
1053 |
+
fmap = []
|
1054 |
+
|
1055 |
+
for l in self.convs:
|
1056 |
+
x = l(x)
|
1057 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
1058 |
+
fmap.append(x)
|
1059 |
+
x = self.conv_post(x)
|
1060 |
+
fmap.append(x)
|
1061 |
+
x = torch.flatten(x, 1, -1)
|
1062 |
+
|
1063 |
+
return x, fmap
|
1064 |
+
|
1065 |
+
|
1066 |
+
class DiscriminatorP(torch.nn.Module):
|
1067 |
+
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
1068 |
+
super(DiscriminatorP, self).__init__()
|
1069 |
+
self.period = period
|
1070 |
+
self.use_spectral_norm = use_spectral_norm
|
1071 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
1072 |
+
self.convs = nn.ModuleList(
|
1073 |
+
[
|
1074 |
+
norm_f(
|
1075 |
+
Conv2d(
|
1076 |
+
1,
|
1077 |
+
32,
|
1078 |
+
(kernel_size, 1),
|
1079 |
+
(stride, 1),
|
1080 |
+
padding=(get_padding(kernel_size, 1), 0),
|
1081 |
+
)
|
1082 |
+
),
|
1083 |
+
norm_f(
|
1084 |
+
Conv2d(
|
1085 |
+
32,
|
1086 |
+
128,
|
1087 |
+
(kernel_size, 1),
|
1088 |
+
(stride, 1),
|
1089 |
+
padding=(get_padding(kernel_size, 1), 0),
|
1090 |
+
)
|
1091 |
+
),
|
1092 |
+
norm_f(
|
1093 |
+
Conv2d(
|
1094 |
+
128,
|
1095 |
+
512,
|
1096 |
+
(kernel_size, 1),
|
1097 |
+
(stride, 1),
|
1098 |
+
padding=(get_padding(kernel_size, 1), 0),
|
1099 |
+
)
|
1100 |
+
),
|
1101 |
+
norm_f(
|
1102 |
+
Conv2d(
|
1103 |
+
512,
|
1104 |
+
1024,
|
1105 |
+
(kernel_size, 1),
|
1106 |
+
(stride, 1),
|
1107 |
+
padding=(get_padding(kernel_size, 1), 0),
|
1108 |
+
)
|
1109 |
+
),
|
1110 |
+
norm_f(
|
1111 |
+
Conv2d(
|
1112 |
+
1024,
|
1113 |
+
1024,
|
1114 |
+
(kernel_size, 1),
|
1115 |
+
1,
|
1116 |
+
padding=(get_padding(kernel_size, 1), 0),
|
1117 |
+
)
|
1118 |
+
),
|
1119 |
+
]
|
1120 |
+
)
|
1121 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
1122 |
+
|
1123 |
+
def forward(self, x):
|
1124 |
+
fmap = []
|
1125 |
+
|
1126 |
+
# 1d to 2d
|
1127 |
+
b, c, t = x.shape
|
1128 |
+
if t % self.period != 0: # pad first
|
1129 |
+
n_pad = self.period - (t % self.period)
|
1130 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
1131 |
+
t = t + n_pad
|
1132 |
+
x = x.view(b, c, t // self.period, self.period)
|
1133 |
+
|
1134 |
+
for l in self.convs:
|
1135 |
+
x = l(x)
|
1136 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
1137 |
+
fmap.append(x)
|
1138 |
+
x = self.conv_post(x)
|
1139 |
+
fmap.append(x)
|
1140 |
+
x = torch.flatten(x, 1, -1)
|
1141 |
+
|
1142 |
+
return x, fmap
|
lib/infer_pack/modules.py
ADDED
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import scipy
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import functional as F
|
8 |
+
|
9 |
+
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
10 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
11 |
+
|
12 |
+
from lib.infer_pack import commons
|
13 |
+
from lib.infer_pack.commons import init_weights, get_padding
|
14 |
+
from lib.infer_pack.transforms import piecewise_rational_quadratic_transform
|
15 |
+
|
16 |
+
|
17 |
+
LRELU_SLOPE = 0.1
|
18 |
+
|
19 |
+
|
20 |
+
class LayerNorm(nn.Module):
|
21 |
+
def __init__(self, channels, eps=1e-5):
|
22 |
+
super().__init__()
|
23 |
+
self.channels = channels
|
24 |
+
self.eps = eps
|
25 |
+
|
26 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
27 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
x = x.transpose(1, -1)
|
31 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
32 |
+
return x.transpose(1, -1)
|
33 |
+
|
34 |
+
|
35 |
+
class ConvReluNorm(nn.Module):
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
in_channels,
|
39 |
+
hidden_channels,
|
40 |
+
out_channels,
|
41 |
+
kernel_size,
|
42 |
+
n_layers,
|
43 |
+
p_dropout,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
self.in_channels = in_channels
|
47 |
+
self.hidden_channels = hidden_channels
|
48 |
+
self.out_channels = out_channels
|
49 |
+
self.kernel_size = kernel_size
|
50 |
+
self.n_layers = n_layers
|
51 |
+
self.p_dropout = p_dropout
|
52 |
+
assert n_layers > 1, "Number of layers should be larger than 0."
|
53 |
+
|
54 |
+
self.conv_layers = nn.ModuleList()
|
55 |
+
self.norm_layers = nn.ModuleList()
|
56 |
+
self.conv_layers.append(
|
57 |
+
nn.Conv1d(
|
58 |
+
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
|
59 |
+
)
|
60 |
+
)
|
61 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
62 |
+
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
|
63 |
+
for _ in range(n_layers - 1):
|
64 |
+
self.conv_layers.append(
|
65 |
+
nn.Conv1d(
|
66 |
+
hidden_channels,
|
67 |
+
hidden_channels,
|
68 |
+
kernel_size,
|
69 |
+
padding=kernel_size // 2,
|
70 |
+
)
|
71 |
+
)
|
72 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
73 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
74 |
+
self.proj.weight.data.zero_()
|
75 |
+
self.proj.bias.data.zero_()
|
76 |
+
|
77 |
+
def forward(self, x, x_mask):
|
78 |
+
x_org = x
|
79 |
+
for i in range(self.n_layers):
|
80 |
+
x = self.conv_layers[i](x * x_mask)
|
81 |
+
x = self.norm_layers[i](x)
|
82 |
+
x = self.relu_drop(x)
|
83 |
+
x = x_org + self.proj(x)
|
84 |
+
return x * x_mask
|
85 |
+
|
86 |
+
|
87 |
+
class DDSConv(nn.Module):
|
88 |
+
"""
|
89 |
+
Dialted and Depth-Separable Convolution
|
90 |
+
"""
|
91 |
+
|
92 |
+
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
|
93 |
+
super().__init__()
|
94 |
+
self.channels = channels
|
95 |
+
self.kernel_size = kernel_size
|
96 |
+
self.n_layers = n_layers
|
97 |
+
self.p_dropout = p_dropout
|
98 |
+
|
99 |
+
self.drop = nn.Dropout(p_dropout)
|
100 |
+
self.convs_sep = nn.ModuleList()
|
101 |
+
self.convs_1x1 = nn.ModuleList()
|
102 |
+
self.norms_1 = nn.ModuleList()
|
103 |
+
self.norms_2 = nn.ModuleList()
|
104 |
+
for i in range(n_layers):
|
105 |
+
dilation = kernel_size**i
|
106 |
+
padding = (kernel_size * dilation - dilation) // 2
|
107 |
+
self.convs_sep.append(
|
108 |
+
nn.Conv1d(
|
109 |
+
channels,
|
110 |
+
channels,
|
111 |
+
kernel_size,
|
112 |
+
groups=channels,
|
113 |
+
dilation=dilation,
|
114 |
+
padding=padding,
|
115 |
+
)
|
116 |
+
)
|
117 |
+
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
118 |
+
self.norms_1.append(LayerNorm(channels))
|
119 |
+
self.norms_2.append(LayerNorm(channels))
|
120 |
+
|
121 |
+
def forward(self, x, x_mask, g=None):
|
122 |
+
if g is not None:
|
123 |
+
x = x + g
|
124 |
+
for i in range(self.n_layers):
|
125 |
+
y = self.convs_sep[i](x * x_mask)
|
126 |
+
y = self.norms_1[i](y)
|
127 |
+
y = F.gelu(y)
|
128 |
+
y = self.convs_1x1[i](y)
|
129 |
+
y = self.norms_2[i](y)
|
130 |
+
y = F.gelu(y)
|
131 |
+
y = self.drop(y)
|
132 |
+
x = x + y
|
133 |
+
return x * x_mask
|
134 |
+
|
135 |
+
|
136 |
+
class WN(torch.nn.Module):
|
137 |
+
def __init__(
|
138 |
+
self,
|
139 |
+
hidden_channels,
|
140 |
+
kernel_size,
|
141 |
+
dilation_rate,
|
142 |
+
n_layers,
|
143 |
+
gin_channels=0,
|
144 |
+
p_dropout=0,
|
145 |
+
):
|
146 |
+
super(WN, self).__init__()
|
147 |
+
assert kernel_size % 2 == 1
|
148 |
+
self.hidden_channels = hidden_channels
|
149 |
+
self.kernel_size = (kernel_size,)
|
150 |
+
self.dilation_rate = dilation_rate
|
151 |
+
self.n_layers = n_layers
|
152 |
+
self.gin_channels = gin_channels
|
153 |
+
self.p_dropout = p_dropout
|
154 |
+
|
155 |
+
self.in_layers = torch.nn.ModuleList()
|
156 |
+
self.res_skip_layers = torch.nn.ModuleList()
|
157 |
+
self.drop = nn.Dropout(p_dropout)
|
158 |
+
|
159 |
+
if gin_channels != 0:
|
160 |
+
cond_layer = torch.nn.Conv1d(
|
161 |
+
gin_channels, 2 * hidden_channels * n_layers, 1
|
162 |
+
)
|
163 |
+
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
|
164 |
+
|
165 |
+
for i in range(n_layers):
|
166 |
+
dilation = dilation_rate**i
|
167 |
+
padding = int((kernel_size * dilation - dilation) / 2)
|
168 |
+
in_layer = torch.nn.Conv1d(
|
169 |
+
hidden_channels,
|
170 |
+
2 * hidden_channels,
|
171 |
+
kernel_size,
|
172 |
+
dilation=dilation,
|
173 |
+
padding=padding,
|
174 |
+
)
|
175 |
+
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
|
176 |
+
self.in_layers.append(in_layer)
|
177 |
+
|
178 |
+
# last one is not necessary
|
179 |
+
if i < n_layers - 1:
|
180 |
+
res_skip_channels = 2 * hidden_channels
|
181 |
+
else:
|
182 |
+
res_skip_channels = hidden_channels
|
183 |
+
|
184 |
+
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
185 |
+
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
|
186 |
+
self.res_skip_layers.append(res_skip_layer)
|
187 |
+
|
188 |
+
def forward(self, x, x_mask, g=None, **kwargs):
|
189 |
+
output = torch.zeros_like(x)
|
190 |
+
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
191 |
+
|
192 |
+
if g is not None:
|
193 |
+
g = self.cond_layer(g)
|
194 |
+
|
195 |
+
for i in range(self.n_layers):
|
196 |
+
x_in = self.in_layers[i](x)
|
197 |
+
if g is not None:
|
198 |
+
cond_offset = i * 2 * self.hidden_channels
|
199 |
+
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
200 |
+
else:
|
201 |
+
g_l = torch.zeros_like(x_in)
|
202 |
+
|
203 |
+
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
204 |
+
acts = self.drop(acts)
|
205 |
+
|
206 |
+
res_skip_acts = self.res_skip_layers[i](acts)
|
207 |
+
if i < self.n_layers - 1:
|
208 |
+
res_acts = res_skip_acts[:, : self.hidden_channels, :]
|
209 |
+
x = (x + res_acts) * x_mask
|
210 |
+
output = output + res_skip_acts[:, self.hidden_channels :, :]
|
211 |
+
else:
|
212 |
+
output = output + res_skip_acts
|
213 |
+
return output * x_mask
|
214 |
+
|
215 |
+
def remove_weight_norm(self):
|
216 |
+
if self.gin_channels != 0:
|
217 |
+
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
218 |
+
for l in self.in_layers:
|
219 |
+
torch.nn.utils.remove_weight_norm(l)
|
220 |
+
for l in self.res_skip_layers:
|
221 |
+
torch.nn.utils.remove_weight_norm(l)
|
222 |
+
|
223 |
+
|
224 |
+
class ResBlock1(torch.nn.Module):
|
225 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
226 |
+
super(ResBlock1, self).__init__()
|
227 |
+
self.convs1 = nn.ModuleList(
|
228 |
+
[
|
229 |
+
weight_norm(
|
230 |
+
Conv1d(
|
231 |
+
channels,
|
232 |
+
channels,
|
233 |
+
kernel_size,
|
234 |
+
1,
|
235 |
+
dilation=dilation[0],
|
236 |
+
padding=get_padding(kernel_size, dilation[0]),
|
237 |
+
)
|
238 |
+
),
|
239 |
+
weight_norm(
|
240 |
+
Conv1d(
|
241 |
+
channels,
|
242 |
+
channels,
|
243 |
+
kernel_size,
|
244 |
+
1,
|
245 |
+
dilation=dilation[1],
|
246 |
+
padding=get_padding(kernel_size, dilation[1]),
|
247 |
+
)
|
248 |
+
),
|
249 |
+
weight_norm(
|
250 |
+
Conv1d(
|
251 |
+
channels,
|
252 |
+
channels,
|
253 |
+
kernel_size,
|
254 |
+
1,
|
255 |
+
dilation=dilation[2],
|
256 |
+
padding=get_padding(kernel_size, dilation[2]),
|
257 |
+
)
|
258 |
+
),
|
259 |
+
]
|
260 |
+
)
|
261 |
+
self.convs1.apply(init_weights)
|
262 |
+
|
263 |
+
self.convs2 = nn.ModuleList(
|
264 |
+
[
|
265 |
+
weight_norm(
|
266 |
+
Conv1d(
|
267 |
+
channels,
|
268 |
+
channels,
|
269 |
+
kernel_size,
|
270 |
+
1,
|
271 |
+
dilation=1,
|
272 |
+
padding=get_padding(kernel_size, 1),
|
273 |
+
)
|
274 |
+
),
|
275 |
+
weight_norm(
|
276 |
+
Conv1d(
|
277 |
+
channels,
|
278 |
+
channels,
|
279 |
+
kernel_size,
|
280 |
+
1,
|
281 |
+
dilation=1,
|
282 |
+
padding=get_padding(kernel_size, 1),
|
283 |
+
)
|
284 |
+
),
|
285 |
+
weight_norm(
|
286 |
+
Conv1d(
|
287 |
+
channels,
|
288 |
+
channels,
|
289 |
+
kernel_size,
|
290 |
+
1,
|
291 |
+
dilation=1,
|
292 |
+
padding=get_padding(kernel_size, 1),
|
293 |
+
)
|
294 |
+
),
|
295 |
+
]
|
296 |
+
)
|
297 |
+
self.convs2.apply(init_weights)
|
298 |
+
|
299 |
+
def forward(self, x, x_mask=None):
|
300 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
301 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
302 |
+
if x_mask is not None:
|
303 |
+
xt = xt * x_mask
|
304 |
+
xt = c1(xt)
|
305 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
306 |
+
if x_mask is not None:
|
307 |
+
xt = xt * x_mask
|
308 |
+
xt = c2(xt)
|
309 |
+
x = xt + x
|
310 |
+
if x_mask is not None:
|
311 |
+
x = x * x_mask
|
312 |
+
return x
|
313 |
+
|
314 |
+
def remove_weight_norm(self):
|
315 |
+
for l in self.convs1:
|
316 |
+
remove_weight_norm(l)
|
317 |
+
for l in self.convs2:
|
318 |
+
remove_weight_norm(l)
|
319 |
+
|
320 |
+
|
321 |
+
class ResBlock2(torch.nn.Module):
|
322 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
323 |
+
super(ResBlock2, self).__init__()
|
324 |
+
self.convs = nn.ModuleList(
|
325 |
+
[
|
326 |
+
weight_norm(
|
327 |
+
Conv1d(
|
328 |
+
channels,
|
329 |
+
channels,
|
330 |
+
kernel_size,
|
331 |
+
1,
|
332 |
+
dilation=dilation[0],
|
333 |
+
padding=get_padding(kernel_size, dilation[0]),
|
334 |
+
)
|
335 |
+
),
|
336 |
+
weight_norm(
|
337 |
+
Conv1d(
|
338 |
+
channels,
|
339 |
+
channels,
|
340 |
+
kernel_size,
|
341 |
+
1,
|
342 |
+
dilation=dilation[1],
|
343 |
+
padding=get_padding(kernel_size, dilation[1]),
|
344 |
+
)
|
345 |
+
),
|
346 |
+
]
|
347 |
+
)
|
348 |
+
self.convs.apply(init_weights)
|
349 |
+
|
350 |
+
def forward(self, x, x_mask=None):
|
351 |
+
for c in self.convs:
|
352 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
353 |
+
if x_mask is not None:
|
354 |
+
xt = xt * x_mask
|
355 |
+
xt = c(xt)
|
356 |
+
x = xt + x
|
357 |
+
if x_mask is not None:
|
358 |
+
x = x * x_mask
|
359 |
+
return x
|
360 |
+
|
361 |
+
def remove_weight_norm(self):
|
362 |
+
for l in self.convs:
|
363 |
+
remove_weight_norm(l)
|
364 |
+
|
365 |
+
|
366 |
+
class Log(nn.Module):
|
367 |
+
def forward(self, x, x_mask, reverse=False, **kwargs):
|
368 |
+
if not reverse:
|
369 |
+
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
|
370 |
+
logdet = torch.sum(-y, [1, 2])
|
371 |
+
return y, logdet
|
372 |
+
else:
|
373 |
+
x = torch.exp(x) * x_mask
|
374 |
+
return x
|
375 |
+
|
376 |
+
|
377 |
+
class Flip(nn.Module):
|
378 |
+
def forward(self, x, *args, reverse=False, **kwargs):
|
379 |
+
x = torch.flip(x, [1])
|
380 |
+
if not reverse:
|
381 |
+
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
|
382 |
+
return x, logdet
|
383 |
+
else:
|
384 |
+
return x
|
385 |
+
|
386 |
+
|
387 |
+
class ElementwiseAffine(nn.Module):
|
388 |
+
def __init__(self, channels):
|
389 |
+
super().__init__()
|
390 |
+
self.channels = channels
|
391 |
+
self.m = nn.Parameter(torch.zeros(channels, 1))
|
392 |
+
self.logs = nn.Parameter(torch.zeros(channels, 1))
|
393 |
+
|
394 |
+
def forward(self, x, x_mask, reverse=False, **kwargs):
|
395 |
+
if not reverse:
|
396 |
+
y = self.m + torch.exp(self.logs) * x
|
397 |
+
y = y * x_mask
|
398 |
+
logdet = torch.sum(self.logs * x_mask, [1, 2])
|
399 |
+
return y, logdet
|
400 |
+
else:
|
401 |
+
x = (x - self.m) * torch.exp(-self.logs) * x_mask
|
402 |
+
return x
|
403 |
+
|
404 |
+
|
405 |
+
class ResidualCouplingLayer(nn.Module):
|
406 |
+
def __init__(
|
407 |
+
self,
|
408 |
+
channels,
|
409 |
+
hidden_channels,
|
410 |
+
kernel_size,
|
411 |
+
dilation_rate,
|
412 |
+
n_layers,
|
413 |
+
p_dropout=0,
|
414 |
+
gin_channels=0,
|
415 |
+
mean_only=False,
|
416 |
+
):
|
417 |
+
assert channels % 2 == 0, "channels should be divisible by 2"
|
418 |
+
super().__init__()
|
419 |
+
self.channels = channels
|
420 |
+
self.hidden_channels = hidden_channels
|
421 |
+
self.kernel_size = kernel_size
|
422 |
+
self.dilation_rate = dilation_rate
|
423 |
+
self.n_layers = n_layers
|
424 |
+
self.half_channels = channels // 2
|
425 |
+
self.mean_only = mean_only
|
426 |
+
|
427 |
+
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
428 |
+
self.enc = WN(
|
429 |
+
hidden_channels,
|
430 |
+
kernel_size,
|
431 |
+
dilation_rate,
|
432 |
+
n_layers,
|
433 |
+
p_dropout=p_dropout,
|
434 |
+
gin_channels=gin_channels,
|
435 |
+
)
|
436 |
+
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
437 |
+
self.post.weight.data.zero_()
|
438 |
+
self.post.bias.data.zero_()
|
439 |
+
|
440 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
441 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
442 |
+
h = self.pre(x0) * x_mask
|
443 |
+
h = self.enc(h, x_mask, g=g)
|
444 |
+
stats = self.post(h) * x_mask
|
445 |
+
if not self.mean_only:
|
446 |
+
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
447 |
+
else:
|
448 |
+
m = stats
|
449 |
+
logs = torch.zeros_like(m)
|
450 |
+
|
451 |
+
if not reverse:
|
452 |
+
x1 = m + x1 * torch.exp(logs) * x_mask
|
453 |
+
x = torch.cat([x0, x1], 1)
|
454 |
+
logdet = torch.sum(logs, [1, 2])
|
455 |
+
return x, logdet
|
456 |
+
else:
|
457 |
+
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
458 |
+
x = torch.cat([x0, x1], 1)
|
459 |
+
return x
|
460 |
+
|
461 |
+
def remove_weight_norm(self):
|
462 |
+
self.enc.remove_weight_norm()
|
463 |
+
|
464 |
+
|
465 |
+
class ConvFlow(nn.Module):
|
466 |
+
def __init__(
|
467 |
+
self,
|
468 |
+
in_channels,
|
469 |
+
filter_channels,
|
470 |
+
kernel_size,
|
471 |
+
n_layers,
|
472 |
+
num_bins=10,
|
473 |
+
tail_bound=5.0,
|
474 |
+
):
|
475 |
+
super().__init__()
|
476 |
+
self.in_channels = in_channels
|
477 |
+
self.filter_channels = filter_channels
|
478 |
+
self.kernel_size = kernel_size
|
479 |
+
self.n_layers = n_layers
|
480 |
+
self.num_bins = num_bins
|
481 |
+
self.tail_bound = tail_bound
|
482 |
+
self.half_channels = in_channels // 2
|
483 |
+
|
484 |
+
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
|
485 |
+
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
|
486 |
+
self.proj = nn.Conv1d(
|
487 |
+
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
|
488 |
+
)
|
489 |
+
self.proj.weight.data.zero_()
|
490 |
+
self.proj.bias.data.zero_()
|
491 |
+
|
492 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
493 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
494 |
+
h = self.pre(x0)
|
495 |
+
h = self.convs(h, x_mask, g=g)
|
496 |
+
h = self.proj(h) * x_mask
|
497 |
+
|
498 |
+
b, c, t = x0.shape
|
499 |
+
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
|
500 |
+
|
501 |
+
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
|
502 |
+
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
|
503 |
+
self.filter_channels
|
504 |
+
)
|
505 |
+
unnormalized_derivatives = h[..., 2 * self.num_bins :]
|
506 |
+
|
507 |
+
x1, logabsdet = piecewise_rational_quadratic_transform(
|
508 |
+
x1,
|
509 |
+
unnormalized_widths,
|
510 |
+
unnormalized_heights,
|
511 |
+
unnormalized_derivatives,
|
512 |
+
inverse=reverse,
|
513 |
+
tails="linear",
|
514 |
+
tail_bound=self.tail_bound,
|
515 |
+
)
|
516 |
+
|
517 |
+
x = torch.cat([x0, x1], 1) * x_mask
|
518 |
+
logdet = torch.sum(logabsdet * x_mask, [1, 2])
|
519 |
+
if not reverse:
|
520 |
+
return x, logdet
|
521 |
+
else:
|
522 |
+
return x
|
lib/infer_pack/transforms.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.nn import functional as F
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
8 |
+
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
9 |
+
DEFAULT_MIN_DERIVATIVE = 1e-3
|
10 |
+
|
11 |
+
|
12 |
+
def piecewise_rational_quadratic_transform(
|
13 |
+
inputs,
|
14 |
+
unnormalized_widths,
|
15 |
+
unnormalized_heights,
|
16 |
+
unnormalized_derivatives,
|
17 |
+
inverse=False,
|
18 |
+
tails=None,
|
19 |
+
tail_bound=1.0,
|
20 |
+
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
21 |
+
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
22 |
+
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
23 |
+
):
|
24 |
+
if tails is None:
|
25 |
+
spline_fn = rational_quadratic_spline
|
26 |
+
spline_kwargs = {}
|
27 |
+
else:
|
28 |
+
spline_fn = unconstrained_rational_quadratic_spline
|
29 |
+
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
|
30 |
+
|
31 |
+
outputs, logabsdet = spline_fn(
|
32 |
+
inputs=inputs,
|
33 |
+
unnormalized_widths=unnormalized_widths,
|
34 |
+
unnormalized_heights=unnormalized_heights,
|
35 |
+
unnormalized_derivatives=unnormalized_derivatives,
|
36 |
+
inverse=inverse,
|
37 |
+
min_bin_width=min_bin_width,
|
38 |
+
min_bin_height=min_bin_height,
|
39 |
+
min_derivative=min_derivative,
|
40 |
+
**spline_kwargs
|
41 |
+
)
|
42 |
+
return outputs, logabsdet
|
43 |
+
|
44 |
+
|
45 |
+
def searchsorted(bin_locations, inputs, eps=1e-6):
|
46 |
+
bin_locations[..., -1] += eps
|
47 |
+
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
|
48 |
+
|
49 |
+
|
50 |
+
def unconstrained_rational_quadratic_spline(
|
51 |
+
inputs,
|
52 |
+
unnormalized_widths,
|
53 |
+
unnormalized_heights,
|
54 |
+
unnormalized_derivatives,
|
55 |
+
inverse=False,
|
56 |
+
tails="linear",
|
57 |
+
tail_bound=1.0,
|
58 |
+
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
59 |
+
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
60 |
+
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
61 |
+
):
|
62 |
+
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
63 |
+
outside_interval_mask = ~inside_interval_mask
|
64 |
+
|
65 |
+
outputs = torch.zeros_like(inputs)
|
66 |
+
logabsdet = torch.zeros_like(inputs)
|
67 |
+
|
68 |
+
if tails == "linear":
|
69 |
+
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
|
70 |
+
constant = np.log(np.exp(1 - min_derivative) - 1)
|
71 |
+
unnormalized_derivatives[..., 0] = constant
|
72 |
+
unnormalized_derivatives[..., -1] = constant
|
73 |
+
|
74 |
+
outputs[outside_interval_mask] = inputs[outside_interval_mask]
|
75 |
+
logabsdet[outside_interval_mask] = 0
|
76 |
+
else:
|
77 |
+
raise RuntimeError("{} tails are not implemented.".format(tails))
|
78 |
+
|
79 |
+
(
|
80 |
+
outputs[inside_interval_mask],
|
81 |
+
logabsdet[inside_interval_mask],
|
82 |
+
) = rational_quadratic_spline(
|
83 |
+
inputs=inputs[inside_interval_mask],
|
84 |
+
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
|
85 |
+
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
|
86 |
+
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
|
87 |
+
inverse=inverse,
|
88 |
+
left=-tail_bound,
|
89 |
+
right=tail_bound,
|
90 |
+
bottom=-tail_bound,
|
91 |
+
top=tail_bound,
|
92 |
+
min_bin_width=min_bin_width,
|
93 |
+
min_bin_height=min_bin_height,
|
94 |
+
min_derivative=min_derivative,
|
95 |
+
)
|
96 |
+
|
97 |
+
return outputs, logabsdet
|
98 |
+
|
99 |
+
|
100 |
+
def rational_quadratic_spline(
|
101 |
+
inputs,
|
102 |
+
unnormalized_widths,
|
103 |
+
unnormalized_heights,
|
104 |
+
unnormalized_derivatives,
|
105 |
+
inverse=False,
|
106 |
+
left=0.0,
|
107 |
+
right=1.0,
|
108 |
+
bottom=0.0,
|
109 |
+
top=1.0,
|
110 |
+
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
111 |
+
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
112 |
+
min_derivative=DEFAULT_MIN_DERIVATIVE,
|
113 |
+
):
|
114 |
+
if torch.min(inputs) < left or torch.max(inputs) > right:
|
115 |
+
raise ValueError("Input to a transform is not within its domain")
|
116 |
+
|
117 |
+
num_bins = unnormalized_widths.shape[-1]
|
118 |
+
|
119 |
+
if min_bin_width * num_bins > 1.0:
|
120 |
+
raise ValueError("Minimal bin width too large for the number of bins")
|
121 |
+
if min_bin_height * num_bins > 1.0:
|
122 |
+
raise ValueError("Minimal bin height too large for the number of bins")
|
123 |
+
|
124 |
+
widths = F.softmax(unnormalized_widths, dim=-1)
|
125 |
+
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
126 |
+
cumwidths = torch.cumsum(widths, dim=-1)
|
127 |
+
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
|
128 |
+
cumwidths = (right - left) * cumwidths + left
|
129 |
+
cumwidths[..., 0] = left
|
130 |
+
cumwidths[..., -1] = right
|
131 |
+
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
|
132 |
+
|
133 |
+
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
|
134 |
+
|
135 |
+
heights = F.softmax(unnormalized_heights, dim=-1)
|
136 |
+
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
|
137 |
+
cumheights = torch.cumsum(heights, dim=-1)
|
138 |
+
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
|
139 |
+
cumheights = (top - bottom) * cumheights + bottom
|
140 |
+
cumheights[..., 0] = bottom
|
141 |
+
cumheights[..., -1] = top
|
142 |
+
heights = cumheights[..., 1:] - cumheights[..., :-1]
|
143 |
+
|
144 |
+
if inverse:
|
145 |
+
bin_idx = searchsorted(cumheights, inputs)[..., None]
|
146 |
+
else:
|
147 |
+
bin_idx = searchsorted(cumwidths, inputs)[..., None]
|
148 |
+
|
149 |
+
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
|
150 |
+
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
|
151 |
+
|
152 |
+
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
|
153 |
+
delta = heights / widths
|
154 |
+
input_delta = delta.gather(-1, bin_idx)[..., 0]
|
155 |
+
|
156 |
+
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
|
157 |
+
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
|
158 |
+
|
159 |
+
input_heights = heights.gather(-1, bin_idx)[..., 0]
|
160 |
+
|
161 |
+
if inverse:
|
162 |
+
a = (inputs - input_cumheights) * (
|
163 |
+
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
164 |
+
) + input_heights * (input_delta - input_derivatives)
|
165 |
+
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
|
166 |
+
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
167 |
+
)
|
168 |
+
c = -input_delta * (inputs - input_cumheights)
|
169 |
+
|
170 |
+
discriminant = b.pow(2) - 4 * a * c
|
171 |
+
assert (discriminant >= 0).all()
|
172 |
+
|
173 |
+
root = (2 * c) / (-b - torch.sqrt(discriminant))
|
174 |
+
outputs = root * input_bin_widths + input_cumwidths
|
175 |
+
|
176 |
+
theta_one_minus_theta = root * (1 - root)
|
177 |
+
denominator = input_delta + (
|
178 |
+
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
179 |
+
* theta_one_minus_theta
|
180 |
+
)
|
181 |
+
derivative_numerator = input_delta.pow(2) * (
|
182 |
+
input_derivatives_plus_one * root.pow(2)
|
183 |
+
+ 2 * input_delta * theta_one_minus_theta
|
184 |
+
+ input_derivatives * (1 - root).pow(2)
|
185 |
+
)
|
186 |
+
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
187 |
+
|
188 |
+
return outputs, -logabsdet
|
189 |
+
else:
|
190 |
+
theta = (inputs - input_cumwidths) / input_bin_widths
|
191 |
+
theta_one_minus_theta = theta * (1 - theta)
|
192 |
+
|
193 |
+
numerator = input_heights * (
|
194 |
+
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
|
195 |
+
)
|
196 |
+
denominator = input_delta + (
|
197 |
+
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
198 |
+
* theta_one_minus_theta
|
199 |
+
)
|
200 |
+
outputs = input_cumheights + numerator / denominator
|
201 |
+
|
202 |
+
derivative_numerator = input_delta.pow(2) * (
|
203 |
+
input_derivatives_plus_one * theta.pow(2)
|
204 |
+
+ 2 * input_delta * theta_one_minus_theta
|
205 |
+
+ input_derivatives * (1 - theta).pow(2)
|
206 |
+
)
|
207 |
+
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
208 |
+
|
209 |
+
return outputs, logabsdet
|
lib/rmvpe.py
ADDED
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch, numpy as np
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
class BiGRU(nn.Module):
|
8 |
+
def __init__(self, input_features, hidden_features, num_layers):
|
9 |
+
super(BiGRU, self).__init__()
|
10 |
+
self.gru = nn.GRU(
|
11 |
+
input_features,
|
12 |
+
hidden_features,
|
13 |
+
num_layers=num_layers,
|
14 |
+
batch_first=True,
|
15 |
+
bidirectional=True,
|
16 |
+
)
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
return self.gru(x)[0]
|
20 |
+
|
21 |
+
|
22 |
+
class ConvBlockRes(nn.Module):
|
23 |
+
def __init__(self, in_channels, out_channels, momentum=0.01):
|
24 |
+
super(ConvBlockRes, self).__init__()
|
25 |
+
self.conv = nn.Sequential(
|
26 |
+
nn.Conv2d(
|
27 |
+
in_channels=in_channels,
|
28 |
+
out_channels=out_channels,
|
29 |
+
kernel_size=(3, 3),
|
30 |
+
stride=(1, 1),
|
31 |
+
padding=(1, 1),
|
32 |
+
bias=False,
|
33 |
+
),
|
34 |
+
nn.BatchNorm2d(out_channels, momentum=momentum),
|
35 |
+
nn.ReLU(),
|
36 |
+
nn.Conv2d(
|
37 |
+
in_channels=out_channels,
|
38 |
+
out_channels=out_channels,
|
39 |
+
kernel_size=(3, 3),
|
40 |
+
stride=(1, 1),
|
41 |
+
padding=(1, 1),
|
42 |
+
bias=False,
|
43 |
+
),
|
44 |
+
nn.BatchNorm2d(out_channels, momentum=momentum),
|
45 |
+
nn.ReLU(),
|
46 |
+
)
|
47 |
+
if in_channels != out_channels:
|
48 |
+
self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
|
49 |
+
self.is_shortcut = True
|
50 |
+
else:
|
51 |
+
self.is_shortcut = False
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
if self.is_shortcut:
|
55 |
+
return self.conv(x) + self.shortcut(x)
|
56 |
+
else:
|
57 |
+
return self.conv(x) + x
|
58 |
+
|
59 |
+
|
60 |
+
class Encoder(nn.Module):
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
in_channels,
|
64 |
+
in_size,
|
65 |
+
n_encoders,
|
66 |
+
kernel_size,
|
67 |
+
n_blocks,
|
68 |
+
out_channels=16,
|
69 |
+
momentum=0.01,
|
70 |
+
):
|
71 |
+
super(Encoder, self).__init__()
|
72 |
+
self.n_encoders = n_encoders
|
73 |
+
self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
|
74 |
+
self.layers = nn.ModuleList()
|
75 |
+
self.latent_channels = []
|
76 |
+
for i in range(self.n_encoders):
|
77 |
+
self.layers.append(
|
78 |
+
ResEncoderBlock(
|
79 |
+
in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
|
80 |
+
)
|
81 |
+
)
|
82 |
+
self.latent_channels.append([out_channels, in_size])
|
83 |
+
in_channels = out_channels
|
84 |
+
out_channels *= 2
|
85 |
+
in_size //= 2
|
86 |
+
self.out_size = in_size
|
87 |
+
self.out_channel = out_channels
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
concat_tensors = []
|
91 |
+
x = self.bn(x)
|
92 |
+
for i in range(self.n_encoders):
|
93 |
+
_, x = self.layers[i](x)
|
94 |
+
concat_tensors.append(_)
|
95 |
+
return x, concat_tensors
|
96 |
+
|
97 |
+
|
98 |
+
class ResEncoderBlock(nn.Module):
|
99 |
+
def __init__(
|
100 |
+
self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
|
101 |
+
):
|
102 |
+
super(ResEncoderBlock, self).__init__()
|
103 |
+
self.n_blocks = n_blocks
|
104 |
+
self.conv = nn.ModuleList()
|
105 |
+
self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
|
106 |
+
for i in range(n_blocks - 1):
|
107 |
+
self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
|
108 |
+
self.kernel_size = kernel_size
|
109 |
+
if self.kernel_size is not None:
|
110 |
+
self.pool = nn.AvgPool2d(kernel_size=kernel_size)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
for i in range(self.n_blocks):
|
114 |
+
x = self.conv[i](x)
|
115 |
+
if self.kernel_size is not None:
|
116 |
+
return x, self.pool(x)
|
117 |
+
else:
|
118 |
+
return x
|
119 |
+
|
120 |
+
|
121 |
+
class Intermediate(nn.Module): #
|
122 |
+
def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
|
123 |
+
super(Intermediate, self).__init__()
|
124 |
+
self.n_inters = n_inters
|
125 |
+
self.layers = nn.ModuleList()
|
126 |
+
self.layers.append(
|
127 |
+
ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
|
128 |
+
)
|
129 |
+
for i in range(self.n_inters - 1):
|
130 |
+
self.layers.append(
|
131 |
+
ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
|
132 |
+
)
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
for i in range(self.n_inters):
|
136 |
+
x = self.layers[i](x)
|
137 |
+
return x
|
138 |
+
|
139 |
+
|
140 |
+
class ResDecoderBlock(nn.Module):
|
141 |
+
def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
|
142 |
+
super(ResDecoderBlock, self).__init__()
|
143 |
+
out_padding = (0, 1) if stride == (1, 2) else (1, 1)
|
144 |
+
self.n_blocks = n_blocks
|
145 |
+
self.conv1 = nn.Sequential(
|
146 |
+
nn.ConvTranspose2d(
|
147 |
+
in_channels=in_channels,
|
148 |
+
out_channels=out_channels,
|
149 |
+
kernel_size=(3, 3),
|
150 |
+
stride=stride,
|
151 |
+
padding=(1, 1),
|
152 |
+
output_padding=out_padding,
|
153 |
+
bias=False,
|
154 |
+
),
|
155 |
+
nn.BatchNorm2d(out_channels, momentum=momentum),
|
156 |
+
nn.ReLU(),
|
157 |
+
)
|
158 |
+
self.conv2 = nn.ModuleList()
|
159 |
+
self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
|
160 |
+
for i in range(n_blocks - 1):
|
161 |
+
self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
|
162 |
+
|
163 |
+
def forward(self, x, concat_tensor):
|
164 |
+
x = self.conv1(x)
|
165 |
+
x = torch.cat((x, concat_tensor), dim=1)
|
166 |
+
for i in range(self.n_blocks):
|
167 |
+
x = self.conv2[i](x)
|
168 |
+
return x
|
169 |
+
|
170 |
+
|
171 |
+
class Decoder(nn.Module):
|
172 |
+
def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
|
173 |
+
super(Decoder, self).__init__()
|
174 |
+
self.layers = nn.ModuleList()
|
175 |
+
self.n_decoders = n_decoders
|
176 |
+
for i in range(self.n_decoders):
|
177 |
+
out_channels = in_channels // 2
|
178 |
+
self.layers.append(
|
179 |
+
ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
|
180 |
+
)
|
181 |
+
in_channels = out_channels
|
182 |
+
|
183 |
+
def forward(self, x, concat_tensors):
|
184 |
+
for i in range(self.n_decoders):
|
185 |
+
x = self.layers[i](x, concat_tensors[-1 - i])
|
186 |
+
return x
|
187 |
+
|
188 |
+
|
189 |
+
class DeepUnet(nn.Module):
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
kernel_size,
|
193 |
+
n_blocks,
|
194 |
+
en_de_layers=5,
|
195 |
+
inter_layers=4,
|
196 |
+
in_channels=1,
|
197 |
+
en_out_channels=16,
|
198 |
+
):
|
199 |
+
super(DeepUnet, self).__init__()
|
200 |
+
self.encoder = Encoder(
|
201 |
+
in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
|
202 |
+
)
|
203 |
+
self.intermediate = Intermediate(
|
204 |
+
self.encoder.out_channel // 2,
|
205 |
+
self.encoder.out_channel,
|
206 |
+
inter_layers,
|
207 |
+
n_blocks,
|
208 |
+
)
|
209 |
+
self.decoder = Decoder(
|
210 |
+
self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
|
211 |
+
)
|
212 |
+
|
213 |
+
def forward(self, x):
|
214 |
+
x, concat_tensors = self.encoder(x)
|
215 |
+
x = self.intermediate(x)
|
216 |
+
x = self.decoder(x, concat_tensors)
|
217 |
+
return x
|
218 |
+
|
219 |
+
|
220 |
+
class E2E(nn.Module):
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
n_blocks,
|
224 |
+
n_gru,
|
225 |
+
kernel_size,
|
226 |
+
en_de_layers=5,
|
227 |
+
inter_layers=4,
|
228 |
+
in_channels=1,
|
229 |
+
en_out_channels=16,
|
230 |
+
):
|
231 |
+
super(E2E, self).__init__()
|
232 |
+
self.unet = DeepUnet(
|
233 |
+
kernel_size,
|
234 |
+
n_blocks,
|
235 |
+
en_de_layers,
|
236 |
+
inter_layers,
|
237 |
+
in_channels,
|
238 |
+
en_out_channels,
|
239 |
+
)
|
240 |
+
self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
|
241 |
+
if n_gru:
|
242 |
+
self.fc = nn.Sequential(
|
243 |
+
BiGRU(3 * 128, 256, n_gru),
|
244 |
+
nn.Linear(512, 360),
|
245 |
+
nn.Dropout(0.25),
|
246 |
+
nn.Sigmoid(),
|
247 |
+
)
|
248 |
+
else:
|
249 |
+
self.fc = nn.Sequential(
|
250 |
+
nn.Linear(3 * nn.N_MELS, nn.N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
|
251 |
+
)
|
252 |
+
|
253 |
+
def forward(self, mel):
|
254 |
+
mel = mel.transpose(-1, -2).unsqueeze(1)
|
255 |
+
x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
|
256 |
+
x = self.fc(x)
|
257 |
+
return x
|
258 |
+
|
259 |
+
|
260 |
+
from librosa.filters import mel
|
261 |
+
|
262 |
+
|
263 |
+
class MelSpectrogram(torch.nn.Module):
|
264 |
+
def __init__(
|
265 |
+
self,
|
266 |
+
is_half,
|
267 |
+
n_mel_channels,
|
268 |
+
sampling_rate,
|
269 |
+
win_length,
|
270 |
+
hop_length,
|
271 |
+
n_fft=None,
|
272 |
+
mel_fmin=0,
|
273 |
+
mel_fmax=None,
|
274 |
+
clamp=1e-5,
|
275 |
+
):
|
276 |
+
super().__init__()
|
277 |
+
n_fft = win_length if n_fft is None else n_fft
|
278 |
+
self.hann_window = {}
|
279 |
+
mel_basis = mel(
|
280 |
+
sr=sampling_rate,
|
281 |
+
n_fft=n_fft,
|
282 |
+
n_mels=n_mel_channels,
|
283 |
+
fmin=mel_fmin,
|
284 |
+
fmax=mel_fmax,
|
285 |
+
htk=True,
|
286 |
+
)
|
287 |
+
mel_basis = torch.from_numpy(mel_basis).float()
|
288 |
+
self.register_buffer("mel_basis", mel_basis)
|
289 |
+
self.n_fft = win_length if n_fft is None else n_fft
|
290 |
+
self.hop_length = hop_length
|
291 |
+
self.win_length = win_length
|
292 |
+
self.sampling_rate = sampling_rate
|
293 |
+
self.n_mel_channels = n_mel_channels
|
294 |
+
self.clamp = clamp
|
295 |
+
self.is_half = is_half
|
296 |
+
|
297 |
+
def forward(self, audio, keyshift=0, speed=1, center=True):
|
298 |
+
factor = 2 ** (keyshift / 12)
|
299 |
+
n_fft_new = int(np.round(self.n_fft * factor))
|
300 |
+
win_length_new = int(np.round(self.win_length * factor))
|
301 |
+
hop_length_new = int(np.round(self.hop_length * speed))
|
302 |
+
keyshift_key = str(keyshift) + "_" + str(audio.device)
|
303 |
+
if keyshift_key not in self.hann_window:
|
304 |
+
self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
|
305 |
+
audio.device
|
306 |
+
)
|
307 |
+
fft = torch.stft(
|
308 |
+
audio,
|
309 |
+
n_fft=n_fft_new,
|
310 |
+
hop_length=hop_length_new,
|
311 |
+
win_length=win_length_new,
|
312 |
+
window=self.hann_window[keyshift_key],
|
313 |
+
center=center,
|
314 |
+
return_complex=True,
|
315 |
+
)
|
316 |
+
magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
|
317 |
+
if keyshift != 0:
|
318 |
+
size = self.n_fft // 2 + 1
|
319 |
+
resize = magnitude.size(1)
|
320 |
+
if resize < size:
|
321 |
+
magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
|
322 |
+
magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
|
323 |
+
mel_output = torch.matmul(self.mel_basis, magnitude)
|
324 |
+
if self.is_half == True:
|
325 |
+
mel_output = mel_output.half()
|
326 |
+
log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
|
327 |
+
return log_mel_spec
|
328 |
+
|
329 |
+
|
330 |
+
class RMVPE:
|
331 |
+
def __init__(self, model_path, is_half, device=None):
|
332 |
+
self.resample_kernel = {}
|
333 |
+
model = E2E(4, 1, (2, 2))
|
334 |
+
ckpt = torch.load(model_path, map_location="cpu")
|
335 |
+
model.load_state_dict(ckpt)
|
336 |
+
model.eval()
|
337 |
+
if is_half == True:
|
338 |
+
model = model.half()
|
339 |
+
self.model = model
|
340 |
+
self.resample_kernel = {}
|
341 |
+
self.is_half = is_half
|
342 |
+
if device is None:
|
343 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
344 |
+
self.device = device
|
345 |
+
self.mel_extractor = MelSpectrogram(
|
346 |
+
is_half, 128, 16000, 1024, 160, None, 30, 8000
|
347 |
+
).to(device)
|
348 |
+
self.model = self.model.to(device)
|
349 |
+
cents_mapping = 20 * np.arange(360) + 1997.3794084376191
|
350 |
+
self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
|
351 |
+
|
352 |
+
def mel2hidden(self, mel):
|
353 |
+
with torch.no_grad():
|
354 |
+
n_frames = mel.shape[-1]
|
355 |
+
mel = F.pad(
|
356 |
+
mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
|
357 |
+
)
|
358 |
+
hidden = self.model(mel)
|
359 |
+
return hidden[:, :n_frames]
|
360 |
+
|
361 |
+
def decode(self, hidden, thred=0.03):
|
362 |
+
cents_pred = self.to_local_average_cents(hidden, thred=thred)
|
363 |
+
f0 = 10 * (2 ** (cents_pred / 1200))
|
364 |
+
f0[f0 == 10] = 0
|
365 |
+
# f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
|
366 |
+
return f0
|
367 |
+
|
368 |
+
def infer_from_audio(self, audio, thred=0.03):
|
369 |
+
audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
|
370 |
+
# torch.cuda.synchronize()
|
371 |
+
# t0=ttime()
|
372 |
+
mel = self.mel_extractor(audio, center=True)
|
373 |
+
# torch.cuda.synchronize()
|
374 |
+
# t1=ttime()
|
375 |
+
hidden = self.mel2hidden(mel)
|
376 |
+
# torch.cuda.synchronize()
|
377 |
+
# t2=ttime()
|
378 |
+
hidden = hidden.squeeze(0).cpu().numpy()
|
379 |
+
if self.is_half == True:
|
380 |
+
hidden = hidden.astype("float32")
|
381 |
+
f0 = self.decode(hidden, thred=thred)
|
382 |
+
# torch.cuda.synchronize()
|
383 |
+
# t3=ttime()
|
384 |
+
# print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
|
385 |
+
return f0
|
386 |
+
|
387 |
+
def pitch_based_audio_inference(self, audio, thred=0.03, f0_min=50, f0_max=1100):
|
388 |
+
audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
|
389 |
+
mel = self.mel_extractor(audio, center=True)
|
390 |
+
hidden = self.mel2hidden(mel)
|
391 |
+
hidden = hidden.squeeze(0).cpu().numpy()
|
392 |
+
if self.is_half == True:
|
393 |
+
hidden = hidden.astype("float32")
|
394 |
+
f0 = self.decode(hidden, thred=thred)
|
395 |
+
f0[(f0 < f0_min) | (f0 > f0_max)] = 0
|
396 |
+
return f0
|
397 |
+
|
398 |
+
def to_local_average_cents(self, salience, thred=0.05):
|
399 |
+
# t0 = ttime()
|
400 |
+
center = np.argmax(salience, axis=1) # frame length#index
|
401 |
+
salience = np.pad(salience, ((0, 0), (4, 4))) # frame length,368
|
402 |
+
# t1 = ttime()
|
403 |
+
center += 4
|
404 |
+
todo_salience = []
|
405 |
+
todo_cents_mapping = []
|
406 |
+
starts = center - 4
|
407 |
+
ends = center + 5
|
408 |
+
for idx in range(salience.shape[0]):
|
409 |
+
todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
|
410 |
+
todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
|
411 |
+
# t2 = ttime()
|
412 |
+
todo_salience = np.array(todo_salience) # frame length,9
|
413 |
+
todo_cents_mapping = np.array(todo_cents_mapping) # frame length,9
|
414 |
+
product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
|
415 |
+
weight_sum = np.sum(todo_salience, 1) # frame length
|
416 |
+
devided = product_sum / weight_sum # frame length
|
417 |
+
# t3 = ttime()
|
418 |
+
maxx = np.max(salience, axis=1) # frame length
|
419 |
+
devided[maxx <= thred] = 0
|
420 |
+
# t4 = ttime()
|
421 |
+
# print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
422 |
+
return devided
|
mdx_models/data.json
ADDED
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"0ddfc0eb5792638ad5dc27850236c246": {
|
3 |
+
"compensate": 1.035,
|
4 |
+
"mdx_dim_f_set": 2048,
|
5 |
+
"mdx_dim_t_set": 8,
|
6 |
+
"mdx_n_fft_scale_set": 6144,
|
7 |
+
"primary_stem": "Vocals"
|
8 |
+
},
|
9 |
+
"26d308f91f3423a67dc69a6d12a8793d": {
|
10 |
+
"compensate": 1.035,
|
11 |
+
"mdx_dim_f_set": 2048,
|
12 |
+
"mdx_dim_t_set": 9,
|
13 |
+
"mdx_n_fft_scale_set": 8192,
|
14 |
+
"primary_stem": "Other"
|
15 |
+
},
|
16 |
+
"2cdd429caac38f0194b133884160f2c6": {
|
17 |
+
"compensate": 1.045,
|
18 |
+
"mdx_dim_f_set": 3072,
|
19 |
+
"mdx_dim_t_set": 8,
|
20 |
+
"mdx_n_fft_scale_set": 7680,
|
21 |
+
"primary_stem": "Instrumental"
|
22 |
+
},
|
23 |
+
"2f5501189a2f6db6349916fabe8c90de": {
|
24 |
+
"compensate": 1.035,
|
25 |
+
"mdx_dim_f_set": 2048,
|
26 |
+
"mdx_dim_t_set": 8,
|
27 |
+
"mdx_n_fft_scale_set": 6144,
|
28 |
+
"primary_stem": "Vocals"
|
29 |
+
},
|
30 |
+
"398580b6d5d973af3120df54cee6759d": {
|
31 |
+
"compensate": 1.75,
|
32 |
+
"mdx_dim_f_set": 3072,
|
33 |
+
"mdx_dim_t_set": 8,
|
34 |
+
"mdx_n_fft_scale_set": 7680,
|
35 |
+
"primary_stem": "Vocals"
|
36 |
+
},
|
37 |
+
"488b3e6f8bd3717d9d7c428476be2d75": {
|
38 |
+
"compensate": 1.035,
|
39 |
+
"mdx_dim_f_set": 3072,
|
40 |
+
"mdx_dim_t_set": 8,
|
41 |
+
"mdx_n_fft_scale_set": 7680,
|
42 |
+
"primary_stem": "Instrumental"
|
43 |
+
},
|
44 |
+
"4910e7827f335048bdac11fa967772f9": {
|
45 |
+
"compensate": 1.035,
|
46 |
+
"mdx_dim_f_set": 2048,
|
47 |
+
"mdx_dim_t_set": 7,
|
48 |
+
"mdx_n_fft_scale_set": 4096,
|
49 |
+
"primary_stem": "Drums"
|
50 |
+
},
|
51 |
+
"53c4baf4d12c3e6c3831bb8f5b532b93": {
|
52 |
+
"compensate": 1.043,
|
53 |
+
"mdx_dim_f_set": 3072,
|
54 |
+
"mdx_dim_t_set": 8,
|
55 |
+
"mdx_n_fft_scale_set": 7680,
|
56 |
+
"primary_stem": "Vocals"
|
57 |
+
},
|
58 |
+
"5d343409ef0df48c7d78cce9f0106781": {
|
59 |
+
"compensate": 1.075,
|
60 |
+
"mdx_dim_f_set": 3072,
|
61 |
+
"mdx_dim_t_set": 8,
|
62 |
+
"mdx_n_fft_scale_set": 7680,
|
63 |
+
"primary_stem": "Vocals"
|
64 |
+
},
|
65 |
+
"5f6483271e1efb9bfb59e4a3e6d4d098": {
|
66 |
+
"compensate": 1.035,
|
67 |
+
"mdx_dim_f_set": 2048,
|
68 |
+
"mdx_dim_t_set": 9,
|
69 |
+
"mdx_n_fft_scale_set": 6144,
|
70 |
+
"primary_stem": "Vocals"
|
71 |
+
},
|
72 |
+
"65ab5919372a128e4167f5e01a8fda85": {
|
73 |
+
"compensate": 1.035,
|
74 |
+
"mdx_dim_f_set": 2048,
|
75 |
+
"mdx_dim_t_set": 8,
|
76 |
+
"mdx_n_fft_scale_set": 8192,
|
77 |
+
"primary_stem": "Other"
|
78 |
+
},
|
79 |
+
"6703e39f36f18aa7855ee1047765621d": {
|
80 |
+
"compensate": 1.035,
|
81 |
+
"mdx_dim_f_set": 2048,
|
82 |
+
"mdx_dim_t_set": 9,
|
83 |
+
"mdx_n_fft_scale_set": 16384,
|
84 |
+
"primary_stem": "Bass"
|
85 |
+
},
|
86 |
+
"6b31de20e84392859a3d09d43f089515": {
|
87 |
+
"compensate": 1.035,
|
88 |
+
"mdx_dim_f_set": 2048,
|
89 |
+
"mdx_dim_t_set": 8,
|
90 |
+
"mdx_n_fft_scale_set": 6144,
|
91 |
+
"primary_stem": "Vocals"
|
92 |
+
},
|
93 |
+
"867595e9de46f6ab699008295df62798": {
|
94 |
+
"compensate": 1.03,
|
95 |
+
"mdx_dim_f_set": 3072,
|
96 |
+
"mdx_dim_t_set": 8,
|
97 |
+
"mdx_n_fft_scale_set": 7680,
|
98 |
+
"primary_stem": "Vocals"
|
99 |
+
},
|
100 |
+
"a3cd63058945e777505c01d2507daf37": {
|
101 |
+
"compensate": 1.03,
|
102 |
+
"mdx_dim_f_set": 2048,
|
103 |
+
"mdx_dim_t_set": 8,
|
104 |
+
"mdx_n_fft_scale_set": 6144,
|
105 |
+
"primary_stem": "Vocals"
|
106 |
+
},
|
107 |
+
"b33d9b3950b6cbf5fe90a32608924700": {
|
108 |
+
"compensate": 1.03,
|
109 |
+
"mdx_dim_f_set": 3072,
|
110 |
+
"mdx_dim_t_set": 8,
|
111 |
+
"mdx_n_fft_scale_set": 7680,
|
112 |
+
"primary_stem": "Vocals"
|
113 |
+
},
|
114 |
+
"c3b29bdce8c4fa17ec609e16220330ab": {
|
115 |
+
"compensate": 1.035,
|
116 |
+
"mdx_dim_f_set": 2048,
|
117 |
+
"mdx_dim_t_set": 8,
|
118 |
+
"mdx_n_fft_scale_set": 16384,
|
119 |
+
"primary_stem": "Bass"
|
120 |
+
},
|
121 |
+
"ceed671467c1f64ebdfac8a2490d0d52": {
|
122 |
+
"compensate": 1.035,
|
123 |
+
"mdx_dim_f_set": 3072,
|
124 |
+
"mdx_dim_t_set": 8,
|
125 |
+
"mdx_n_fft_scale_set": 7680,
|
126 |
+
"primary_stem": "Instrumental"
|
127 |
+
},
|
128 |
+
"d2a1376f310e4f7fa37fb9b5774eb701": {
|
129 |
+
"compensate": 1.035,
|
130 |
+
"mdx_dim_f_set": 3072,
|
131 |
+
"mdx_dim_t_set": 8,
|
132 |
+
"mdx_n_fft_scale_set": 7680,
|
133 |
+
"primary_stem": "Instrumental"
|
134 |
+
},
|
135 |
+
"d7bff498db9324db933d913388cba6be": {
|
136 |
+
"compensate": 1.035,
|
137 |
+
"mdx_dim_f_set": 2048,
|
138 |
+
"mdx_dim_t_set": 8,
|
139 |
+
"mdx_n_fft_scale_set": 6144,
|
140 |
+
"primary_stem": "Vocals"
|
141 |
+
},
|
142 |
+
"d94058f8c7f1fae4164868ae8ae66b20": {
|
143 |
+
"compensate": 1.035,
|
144 |
+
"mdx_dim_f_set": 2048,
|
145 |
+
"mdx_dim_t_set": 8,
|
146 |
+
"mdx_n_fft_scale_set": 6144,
|
147 |
+
"primary_stem": "Vocals"
|
148 |
+
},
|
149 |
+
"dc41ede5961d50f277eb846db17f5319": {
|
150 |
+
"compensate": 1.035,
|
151 |
+
"mdx_dim_f_set": 2048,
|
152 |
+
"mdx_dim_t_set": 9,
|
153 |
+
"mdx_n_fft_scale_set": 4096,
|
154 |
+
"primary_stem": "Drums"
|
155 |
+
},
|
156 |
+
"e5572e58abf111f80d8241d2e44e7fa4": {
|
157 |
+
"compensate": 1.028,
|
158 |
+
"mdx_dim_f_set": 3072,
|
159 |
+
"mdx_dim_t_set": 8,
|
160 |
+
"mdx_n_fft_scale_set": 7680,
|
161 |
+
"primary_stem": "Instrumental"
|
162 |
+
},
|
163 |
+
"e7324c873b1f615c35c1967f912db92a": {
|
164 |
+
"compensate": 1.03,
|
165 |
+
"mdx_dim_f_set": 3072,
|
166 |
+
"mdx_dim_t_set": 8,
|
167 |
+
"mdx_n_fft_scale_set": 7680,
|
168 |
+
"primary_stem": "Vocals"
|
169 |
+
},
|
170 |
+
"1c56ec0224f1d559c42fd6fd2a67b154": {
|
171 |
+
"compensate": 1.025,
|
172 |
+
"mdx_dim_f_set": 2048,
|
173 |
+
"mdx_dim_t_set": 8,
|
174 |
+
"mdx_n_fft_scale_set": 5120,
|
175 |
+
"primary_stem": "Instrumental"
|
176 |
+
},
|
177 |
+
"f2df6d6863d8f435436d8b561594ff49": {
|
178 |
+
"compensate": 1.035,
|
179 |
+
"mdx_dim_f_set": 3072,
|
180 |
+
"mdx_dim_t_set": 8,
|
181 |
+
"mdx_n_fft_scale_set": 7680,
|
182 |
+
"primary_stem": "Instrumental"
|
183 |
+
},
|
184 |
+
"b06327a00d5e5fbc7d96e1781bbdb596": {
|
185 |
+
"compensate": 1.035,
|
186 |
+
"mdx_dim_f_set": 3072,
|
187 |
+
"mdx_dim_t_set": 8,
|
188 |
+
"mdx_n_fft_scale_set": 6144,
|
189 |
+
"primary_stem": "Instrumental"
|
190 |
+
},
|
191 |
+
"94ff780b977d3ca07c7a343dab2e25dd": {
|
192 |
+
"compensate": 1.039,
|
193 |
+
"mdx_dim_f_set": 3072,
|
194 |
+
"mdx_dim_t_set": 8,
|
195 |
+
"mdx_n_fft_scale_set": 6144,
|
196 |
+
"primary_stem": "Instrumental"
|
197 |
+
},
|
198 |
+
"73492b58195c3b52d34590d5474452f6": {
|
199 |
+
"compensate": 1.043,
|
200 |
+
"mdx_dim_f_set": 3072,
|
201 |
+
"mdx_dim_t_set": 8,
|
202 |
+
"mdx_n_fft_scale_set": 7680,
|
203 |
+
"primary_stem": "Vocals"
|
204 |
+
},
|
205 |
+
"970b3f9492014d18fefeedfe4773cb42": {
|
206 |
+
"compensate": 1.009,
|
207 |
+
"mdx_dim_f_set": 3072,
|
208 |
+
"mdx_dim_t_set": 8,
|
209 |
+
"mdx_n_fft_scale_set": 7680,
|
210 |
+
"primary_stem": "Vocals"
|
211 |
+
},
|
212 |
+
"1d64a6d2c30f709b8c9b4ce1366d96ee": {
|
213 |
+
"compensate": 1.035,
|
214 |
+
"mdx_dim_f_set": 2048,
|
215 |
+
"mdx_dim_t_set": 8,
|
216 |
+
"mdx_n_fft_scale_set": 5120,
|
217 |
+
"primary_stem": "Instrumental"
|
218 |
+
},
|
219 |
+
"203f2a3955221b64df85a41af87cf8f0": {
|
220 |
+
"compensate": 1.035,
|
221 |
+
"mdx_dim_f_set": 3072,
|
222 |
+
"mdx_dim_t_set": 8,
|
223 |
+
"mdx_n_fft_scale_set": 6144,
|
224 |
+
"primary_stem": "Instrumental"
|
225 |
+
},
|
226 |
+
"291c2049608edb52648b96e27eb80e95": {
|
227 |
+
"compensate": 1.035,
|
228 |
+
"mdx_dim_f_set": 3072,
|
229 |
+
"mdx_dim_t_set": 8,
|
230 |
+
"mdx_n_fft_scale_set": 6144,
|
231 |
+
"primary_stem": "Instrumental"
|
232 |
+
},
|
233 |
+
"ead8d05dab12ec571d67549b3aab03fc": {
|
234 |
+
"compensate": 1.035,
|
235 |
+
"mdx_dim_f_set": 3072,
|
236 |
+
"mdx_dim_t_set": 8,
|
237 |
+
"mdx_n_fft_scale_set": 6144,
|
238 |
+
"primary_stem": "Instrumental"
|
239 |
+
},
|
240 |
+
"cc63408db3d80b4d85b0287d1d7c9632": {
|
241 |
+
"compensate": 1.033,
|
242 |
+
"mdx_dim_f_set": 3072,
|
243 |
+
"mdx_dim_t_set": 8,
|
244 |
+
"mdx_n_fft_scale_set": 6144,
|
245 |
+
"primary_stem": "Instrumental"
|
246 |
+
},
|
247 |
+
"cd5b2989ad863f116c855db1dfe24e39": {
|
248 |
+
"compensate": 1.035,
|
249 |
+
"mdx_dim_f_set": 3072,
|
250 |
+
"mdx_dim_t_set": 9,
|
251 |
+
"mdx_n_fft_scale_set": 6144,
|
252 |
+
"primary_stem": "Other"
|
253 |
+
},
|
254 |
+
"55657dd70583b0fedfba5f67df11d711": {
|
255 |
+
"compensate": 1.022,
|
256 |
+
"mdx_dim_f_set": 3072,
|
257 |
+
"mdx_dim_t_set": 8,
|
258 |
+
"mdx_n_fft_scale_set": 6144,
|
259 |
+
"primary_stem": "Instrumental"
|
260 |
+
},
|
261 |
+
"b6bccda408a436db8500083ef3491e8b": {
|
262 |
+
"compensate": 1.02,
|
263 |
+
"mdx_dim_f_set": 3072,
|
264 |
+
"mdx_dim_t_set": 8,
|
265 |
+
"mdx_n_fft_scale_set": 7680,
|
266 |
+
"primary_stem": "Instrumental"
|
267 |
+
},
|
268 |
+
"8a88db95c7fb5dbe6a095ff2ffb428b1": {
|
269 |
+
"compensate": 1.026,
|
270 |
+
"mdx_dim_f_set": 2048,
|
271 |
+
"mdx_dim_t_set": 8,
|
272 |
+
"mdx_n_fft_scale_set": 5120,
|
273 |
+
"primary_stem": "Instrumental"
|
274 |
+
},
|
275 |
+
"b78da4afc6512f98e4756f5977f5c6b9": {
|
276 |
+
"compensate": 1.021,
|
277 |
+
"mdx_dim_f_set": 3072,
|
278 |
+
"mdx_dim_t_set": 8,
|
279 |
+
"mdx_n_fft_scale_set": 7680,
|
280 |
+
"primary_stem": "Instrumental"
|
281 |
+
},
|
282 |
+
"77d07b2667ddf05b9e3175941b4454a0": {
|
283 |
+
"compensate": 1.021,
|
284 |
+
"mdx_dim_f_set": 3072,
|
285 |
+
"mdx_dim_t_set": 8,
|
286 |
+
"mdx_n_fft_scale_set": 7680,
|
287 |
+
"primary_stem": "Vocals"
|
288 |
+
},
|
289 |
+
"0f2a6bc5b49d87d64728ee40e23bceb1": {
|
290 |
+
"compensate": 1.019,
|
291 |
+
"mdx_dim_f_set": 2560,
|
292 |
+
"mdx_dim_t_set": 8,
|
293 |
+
"mdx_n_fft_scale_set": 5120,
|
294 |
+
"primary_stem": "Instrumental"
|
295 |
+
},
|
296 |
+
"b02be2d198d4968a121030cf8950b492": {
|
297 |
+
"compensate": 1.020,
|
298 |
+
"mdx_dim_f_set": 2560,
|
299 |
+
"mdx_dim_t_set": 8,
|
300 |
+
"mdx_n_fft_scale_set": 5120,
|
301 |
+
"primary_stem": "No Crowd"
|
302 |
+
},
|
303 |
+
"2154254ee89b2945b97a7efed6e88820": {
|
304 |
+
"config_yaml": "model_2_stem_061321.yaml"
|
305 |
+
},
|
306 |
+
"063aadd735d58150722926dcbf5852a9": {
|
307 |
+
"config_yaml": "model_2_stem_061321.yaml"
|
308 |
+
},
|
309 |
+
"fe96801369f6a148df2720f5ced88c19": {
|
310 |
+
"config_yaml": "model3.yaml"
|
311 |
+
},
|
312 |
+
"02e8b226f85fb566e5db894b9931c640": {
|
313 |
+
"config_yaml": "model2.yaml"
|
314 |
+
},
|
315 |
+
"e3de6d861635ab9c1d766149edd680d6": {
|
316 |
+
"config_yaml": "model1.yaml"
|
317 |
+
},
|
318 |
+
"3f2936c554ab73ce2e396d54636bd373": {
|
319 |
+
"config_yaml": "modelB.yaml"
|
320 |
+
},
|
321 |
+
"890d0f6f82d7574bca741a9e8bcb8168": {
|
322 |
+
"config_yaml": "modelB.yaml"
|
323 |
+
},
|
324 |
+
"63a3cb8c37c474681049be4ad1ba8815": {
|
325 |
+
"config_yaml": "modelB.yaml"
|
326 |
+
},
|
327 |
+
"a7fc5d719743c7fd6b61bd2b4d48b9f0": {
|
328 |
+
"config_yaml": "modelA.yaml"
|
329 |
+
},
|
330 |
+
"3567f3dee6e77bf366fcb1c7b8bc3745": {
|
331 |
+
"config_yaml": "modelA.yaml"
|
332 |
+
},
|
333 |
+
"a28f4d717bd0d34cd2ff7a3b0a3d065e": {
|
334 |
+
"config_yaml": "modelA.yaml"
|
335 |
+
},
|
336 |
+
"c9971a18da20911822593dc81caa8be9": {
|
337 |
+
"config_yaml": "sndfx.yaml"
|
338 |
+
},
|
339 |
+
"57d94d5ed705460d21c75a5ac829a605": {
|
340 |
+
"config_yaml": "sndfx.yaml"
|
341 |
+
},
|
342 |
+
"e7a25f8764f25a52c1b96c4946e66ba2": {
|
343 |
+
"config_yaml": "sndfx.yaml"
|
344 |
+
},
|
345 |
+
"104081d24e37217086ce5fde09147ee1": {
|
346 |
+
"config_yaml": "model_2_stem_061321.yaml"
|
347 |
+
},
|
348 |
+
"1e6165b601539f38d0a9330f3facffeb": {
|
349 |
+
"config_yaml": "model_2_stem_061321.yaml"
|
350 |
+
},
|
351 |
+
"fe0108464ce0d8271be5ab810891bd7c": {
|
352 |
+
"config_yaml": "model_2_stem_full_band.yaml"
|
353 |
+
}
|
354 |
+
}
|
packages.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
git-lfs
|
2 |
+
aria2 -y
|
3 |
+
ffmpeg
|
pre-requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu118
|
2 |
+
torch>=2.1.0+cu118
|
3 |
+
torchvision>=0.16.0+cu118
|
4 |
+
torchaudio>=2.1.0+cu118
|
5 |
+
yt-dlp
|
6 |
+
gradio==4.19.2
|
7 |
+
pydub==0.25.1
|
8 |
+
edge_tts==6.1.7
|
9 |
+
deep_translator==1.11.4
|
10 |
+
git+https://github.com/R3gm/pyannote-audio.git@3.1.1
|
11 |
+
git+https://github.com/R3gm/whisperX.git@cuda_11_8
|
12 |
+
nest_asyncio
|
13 |
+
gTTS
|
14 |
+
gradio_client==0.10.1
|
15 |
+
IPython
|
requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
praat-parselmouth>=0.4.3
|
2 |
+
pyworld==0.3.2
|
3 |
+
faiss-cpu==1.7.3
|
4 |
+
torchcrepe==0.0.20
|
5 |
+
ffmpeg-python>=0.2.0
|
6 |
+
fairseq==0.12.2
|
7 |
+
gdown
|
8 |
+
rarfile
|
9 |
+
transformers
|
10 |
+
accelerate
|
11 |
+
optimum
|
12 |
+
sentencepiece
|
13 |
+
srt
|
14 |
+
git+https://github.com/R3gm/openvoice_package.git@lite
|
15 |
+
openai==1.14.3
|
16 |
+
tiktoken==0.6.0
|
17 |
+
# Documents
|
18 |
+
pypdf==4.2.0
|
19 |
+
python-docx
|
requirements_xtts.txt
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# core deps
|
2 |
+
numpy==1.23.5
|
3 |
+
cython>=0.29.30
|
4 |
+
scipy>=1.11.2
|
5 |
+
torch
|
6 |
+
torchaudio
|
7 |
+
soundfile
|
8 |
+
librosa
|
9 |
+
scikit-learn
|
10 |
+
numba
|
11 |
+
inflect>=5.6.0
|
12 |
+
tqdm>=4.64.1
|
13 |
+
anyascii>=0.3.0
|
14 |
+
pyyaml>=6.0
|
15 |
+
fsspec>=2023.6.0 # <= 2023.9.1 makes aux tests fail
|
16 |
+
aiohttp>=3.8.1
|
17 |
+
packaging>=23.1
|
18 |
+
# deps for examples
|
19 |
+
flask>=2.0.1
|
20 |
+
# deps for inference
|
21 |
+
pysbd>=0.3.4
|
22 |
+
# deps for notebooks
|
23 |
+
umap-learn>=0.5.1
|
24 |
+
pandas
|
25 |
+
# deps for training
|
26 |
+
matplotlib
|
27 |
+
# coqui stack
|
28 |
+
trainer>=0.0.32
|
29 |
+
# config management
|
30 |
+
coqpit>=0.0.16
|
31 |
+
# chinese g2p deps
|
32 |
+
jieba
|
33 |
+
pypinyin
|
34 |
+
# korean
|
35 |
+
hangul_romanize
|
36 |
+
# gruut+supported langs
|
37 |
+
gruut[de,es,fr]==2.2.3
|
38 |
+
# deps for korean
|
39 |
+
jamo
|
40 |
+
nltk
|
41 |
+
g2pkk>=0.1.1
|
42 |
+
# deps for bangla
|
43 |
+
bangla
|
44 |
+
bnnumerizer
|
45 |
+
bnunicodenormalizer
|
46 |
+
#deps for tortoise
|
47 |
+
einops>=0.6.0
|
48 |
+
transformers
|
49 |
+
#deps for bark
|
50 |
+
encodec>=0.1.1
|
51 |
+
# deps for XTTS
|
52 |
+
unidecode>=1.3.2
|
53 |
+
num2words
|
54 |
+
spacy[ja]>=3
|
55 |
+
|
56 |
+
# after this
|
57 |
+
# pip install -r requirements_xtts.txt
|
58 |
+
# pip install TTS==0.21.1 --no-deps
|
soni_translate/audio_segments.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydub import AudioSegment
|
2 |
+
from tqdm import tqdm
|
3 |
+
from .utils import run_command
|
4 |
+
from .logging_setup import logger
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
class Mixer:
|
9 |
+
def __init__(self):
|
10 |
+
self.parts = []
|
11 |
+
|
12 |
+
def __len__(self):
|
13 |
+
parts = self._sync()
|
14 |
+
seg = parts[0][1]
|
15 |
+
frame_count = max(offset + seg.frame_count() for offset, seg in parts)
|
16 |
+
return int(1000.0 * frame_count / seg.frame_rate)
|
17 |
+
|
18 |
+
def overlay(self, sound, position=0):
|
19 |
+
self.parts.append((position, sound))
|
20 |
+
return self
|
21 |
+
|
22 |
+
def _sync(self):
|
23 |
+
positions, segs = zip(*self.parts)
|
24 |
+
|
25 |
+
frame_rate = segs[0].frame_rate
|
26 |
+
array_type = segs[0].array_type # noqa
|
27 |
+
|
28 |
+
offsets = [int(frame_rate * pos / 1000.0) for pos in positions]
|
29 |
+
segs = AudioSegment.empty()._sync(*segs)
|
30 |
+
return list(zip(offsets, segs))
|
31 |
+
|
32 |
+
def append(self, sound):
|
33 |
+
self.overlay(sound, position=len(self))
|
34 |
+
|
35 |
+
def to_audio_segment(self):
|
36 |
+
parts = self._sync()
|
37 |
+
seg = parts[0][1]
|
38 |
+
channels = seg.channels
|
39 |
+
|
40 |
+
frame_count = max(offset + seg.frame_count() for offset, seg in parts)
|
41 |
+
sample_count = int(frame_count * seg.channels)
|
42 |
+
|
43 |
+
output = np.zeros(sample_count, dtype="int32")
|
44 |
+
for offset, seg in parts:
|
45 |
+
sample_offset = offset * channels
|
46 |
+
samples = np.frombuffer(seg.get_array_of_samples(), dtype="int32")
|
47 |
+
samples = np.int16(samples/np.max(np.abs(samples)) * 32767)
|
48 |
+
start = sample_offset
|
49 |
+
end = start + len(samples)
|
50 |
+
output[start:end] += samples
|
51 |
+
|
52 |
+
return seg._spawn(
|
53 |
+
output, overrides={"sample_width": 4}).normalize(headroom=0.0)
|
54 |
+
|
55 |
+
|
56 |
+
def create_translated_audio(
|
57 |
+
result_diarize, audio_files, final_file, concat=False, avoid_overlap=False,
|
58 |
+
):
|
59 |
+
total_duration = result_diarize["segments"][-1]["end"] # in seconds
|
60 |
+
|
61 |
+
if concat:
|
62 |
+
"""
|
63 |
+
file .\audio\1.ogg
|
64 |
+
file .\audio\2.ogg
|
65 |
+
file .\audio\3.ogg
|
66 |
+
file .\audio\4.ogg
|
67 |
+
...
|
68 |
+
"""
|
69 |
+
|
70 |
+
# Write the file paths to list.txt
|
71 |
+
with open("list.txt", "w") as file:
|
72 |
+
for i, audio_file in enumerate(audio_files):
|
73 |
+
if i == len(audio_files) - 1: # Check if it's the last item
|
74 |
+
file.write(f"file {audio_file}")
|
75 |
+
else:
|
76 |
+
file.write(f"file {audio_file}\n")
|
77 |
+
|
78 |
+
# command = f"ffmpeg -f concat -safe 0 -i list.txt {final_file}"
|
79 |
+
command = (
|
80 |
+
f"ffmpeg -f concat -safe 0 -i list.txt -c:a pcm_s16le {final_file}"
|
81 |
+
)
|
82 |
+
run_command(command)
|
83 |
+
|
84 |
+
else:
|
85 |
+
# silent audio with total_duration
|
86 |
+
base_audio = AudioSegment.silent(
|
87 |
+
duration=int(total_duration * 1000), frame_rate=41000
|
88 |
+
)
|
89 |
+
combined_audio = Mixer()
|
90 |
+
combined_audio.overlay(base_audio)
|
91 |
+
|
92 |
+
logger.debug(
|
93 |
+
f"Audio duration: {total_duration // 60} "
|
94 |
+
f"minutes and {int(total_duration % 60)} seconds"
|
95 |
+
)
|
96 |
+
|
97 |
+
last_end_time = 0
|
98 |
+
previous_speaker = ""
|
99 |
+
for line, audio_file in tqdm(
|
100 |
+
zip(result_diarize["segments"], audio_files)
|
101 |
+
):
|
102 |
+
start = float(line["start"])
|
103 |
+
|
104 |
+
# Overlay each audio at the corresponding time
|
105 |
+
try:
|
106 |
+
audio = AudioSegment.from_file(audio_file)
|
107 |
+
# audio_a = audio.speedup(playback_speed=1.5)
|
108 |
+
|
109 |
+
if avoid_overlap:
|
110 |
+
speaker = line["speaker"]
|
111 |
+
if (last_end_time - 0.500) > start:
|
112 |
+
overlap_time = last_end_time - start
|
113 |
+
if previous_speaker and previous_speaker != speaker:
|
114 |
+
start = (last_end_time - 0.500)
|
115 |
+
else:
|
116 |
+
start = (last_end_time - 0.200)
|
117 |
+
if overlap_time > 2.5:
|
118 |
+
start = start - 0.3
|
119 |
+
logger.info(
|
120 |
+
f"Avoid overlap for {str(audio_file)} "
|
121 |
+
f"with {str(start)}"
|
122 |
+
)
|
123 |
+
|
124 |
+
previous_speaker = speaker
|
125 |
+
|
126 |
+
duration_tts_seconds = len(audio) / 1000.0 # to sec
|
127 |
+
last_end_time = (start + duration_tts_seconds)
|
128 |
+
|
129 |
+
start_time = start * 1000 # to ms
|
130 |
+
combined_audio = combined_audio.overlay(
|
131 |
+
audio, position=start_time
|
132 |
+
)
|
133 |
+
except Exception as error:
|
134 |
+
logger.debug(str(error))
|
135 |
+
logger.error(f"Error audio file {audio_file}")
|
136 |
+
|
137 |
+
# combined audio as a file
|
138 |
+
combined_audio_data = combined_audio.to_audio_segment()
|
139 |
+
combined_audio_data.export(
|
140 |
+
final_file, format="wav"
|
141 |
+
) # best than ogg, change if the audio is anomalous
|
soni_translate/language_configuration.py
ADDED
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .logging_setup import logger
|
2 |
+
|
3 |
+
LANGUAGES_UNIDIRECTIONAL = {
|
4 |
+
"Aymara (ay)": "ay",
|
5 |
+
"Bambara (bm)": "bm",
|
6 |
+
"Cebuano (ceb)": "ceb",
|
7 |
+
"Chichewa (ny)": "ny",
|
8 |
+
"Divehi (dv)": "dv",
|
9 |
+
"Dogri (doi)": "doi",
|
10 |
+
"Ewe (ee)": "ee",
|
11 |
+
"Guarani (gn)": "gn",
|
12 |
+
"Iloko (ilo)": "ilo",
|
13 |
+
"Kinyarwanda (rw)": "rw",
|
14 |
+
"Krio (kri)": "kri",
|
15 |
+
"Kurdish (ku)": "ku",
|
16 |
+
"Kirghiz (ky)": "ky",
|
17 |
+
"Ganda (lg)": "lg",
|
18 |
+
"Maithili (mai)": "mai",
|
19 |
+
"Oriya (or)": "or",
|
20 |
+
"Oromo (om)": "om",
|
21 |
+
"Quechua (qu)": "qu",
|
22 |
+
"Samoan (sm)": "sm",
|
23 |
+
"Tigrinya (ti)": "ti",
|
24 |
+
"Tsonga (ts)": "ts",
|
25 |
+
"Akan (ak)": "ak",
|
26 |
+
"Uighur (ug)": "ug"
|
27 |
+
}
|
28 |
+
|
29 |
+
UNIDIRECTIONAL_L_LIST = LANGUAGES_UNIDIRECTIONAL.keys()
|
30 |
+
|
31 |
+
LANGUAGES = {
|
32 |
+
"Automatic detection": "Automatic detection",
|
33 |
+
"Arabic (ar)": "ar",
|
34 |
+
"Chinese - Simplified (zh-CN)": "zh",
|
35 |
+
"Czech (cs)": "cs",
|
36 |
+
"Danish (da)": "da",
|
37 |
+
"Dutch (nl)": "nl",
|
38 |
+
"English (en)": "en",
|
39 |
+
"Finnish (fi)": "fi",
|
40 |
+
"French (fr)": "fr",
|
41 |
+
"German (de)": "de",
|
42 |
+
"Greek (el)": "el",
|
43 |
+
"Hebrew (he)": "he",
|
44 |
+
"Hungarian (hu)": "hu",
|
45 |
+
"Italian (it)": "it",
|
46 |
+
"Japanese (ja)": "ja",
|
47 |
+
"Korean (ko)": "ko",
|
48 |
+
"Persian (fa)": "fa", # no aux gTTS
|
49 |
+
"Polish (pl)": "pl",
|
50 |
+
"Portuguese (pt)": "pt",
|
51 |
+
"Russian (ru)": "ru",
|
52 |
+
"Spanish (es)": "es",
|
53 |
+
"Turkish (tr)": "tr",
|
54 |
+
"Ukrainian (uk)": "uk",
|
55 |
+
"Urdu (ur)": "ur",
|
56 |
+
"Vietnamese (vi)": "vi",
|
57 |
+
"Hindi (hi)": "hi",
|
58 |
+
"Indonesian (id)": "id",
|
59 |
+
"Bengali (bn)": "bn",
|
60 |
+
"Telugu (te)": "te",
|
61 |
+
"Marathi (mr)": "mr",
|
62 |
+
"Tamil (ta)": "ta",
|
63 |
+
"Javanese (jw|jv)": "jw",
|
64 |
+
"Catalan (ca)": "ca",
|
65 |
+
"Nepali (ne)": "ne",
|
66 |
+
"Thai (th)": "th",
|
67 |
+
"Swedish (sv)": "sv",
|
68 |
+
"Amharic (am)": "am",
|
69 |
+
"Welsh (cy)": "cy", # no aux gTTS
|
70 |
+
"Estonian (et)": "et",
|
71 |
+
"Croatian (hr)": "hr",
|
72 |
+
"Icelandic (is)": "is",
|
73 |
+
"Georgian (ka)": "ka", # no aux gTTS
|
74 |
+
"Khmer (km)": "km",
|
75 |
+
"Slovak (sk)": "sk",
|
76 |
+
"Albanian (sq)": "sq",
|
77 |
+
"Serbian (sr)": "sr",
|
78 |
+
"Azerbaijani (az)": "az", # no aux gTTS
|
79 |
+
"Bulgarian (bg)": "bg",
|
80 |
+
"Galician (gl)": "gl", # no aux gTTS
|
81 |
+
"Gujarati (gu)": "gu",
|
82 |
+
"Kazakh (kk)": "kk", # no aux gTTS
|
83 |
+
"Kannada (kn)": "kn",
|
84 |
+
"Lithuanian (lt)": "lt", # no aux gTTS
|
85 |
+
"Latvian (lv)": "lv",
|
86 |
+
"Macedonian (mk)": "mk", # no aux gTTS # error get align model
|
87 |
+
"Malayalam (ml)": "ml",
|
88 |
+
"Malay (ms)": "ms", # error get align model
|
89 |
+
"Romanian (ro)": "ro",
|
90 |
+
"Sinhala (si)": "si",
|
91 |
+
"Sundanese (su)": "su",
|
92 |
+
"Swahili (sw)": "sw", # error aling
|
93 |
+
"Afrikaans (af)": "af",
|
94 |
+
"Bosnian (bs)": "bs",
|
95 |
+
"Latin (la)": "la",
|
96 |
+
"Myanmar Burmese (my)": "my",
|
97 |
+
"Norwegian (no|nb)": "no",
|
98 |
+
"Chinese - Traditional (zh-TW)": "zh-TW",
|
99 |
+
"Assamese (as)": "as",
|
100 |
+
"Basque (eu)": "eu",
|
101 |
+
"Hausa (ha)": "ha",
|
102 |
+
"Haitian Creole (ht)": "ht",
|
103 |
+
"Armenian (hy)": "hy",
|
104 |
+
"Lao (lo)": "lo",
|
105 |
+
"Malagasy (mg)": "mg",
|
106 |
+
"Mongolian (mn)": "mn",
|
107 |
+
"Maltese (mt)": "mt",
|
108 |
+
"Punjabi (pa)": "pa",
|
109 |
+
"Pashto (ps)": "ps",
|
110 |
+
"Slovenian (sl)": "sl",
|
111 |
+
"Shona (sn)": "sn",
|
112 |
+
"Somali (so)": "so",
|
113 |
+
"Tajik (tg)": "tg",
|
114 |
+
"Turkmen (tk)": "tk",
|
115 |
+
"Tatar (tt)": "tt",
|
116 |
+
"Uzbek (uz)": "uz",
|
117 |
+
"Yoruba (yo)": "yo",
|
118 |
+
**LANGUAGES_UNIDIRECTIONAL
|
119 |
+
}
|
120 |
+
|
121 |
+
BASE_L_LIST = LANGUAGES.keys()
|
122 |
+
LANGUAGES_LIST = [list(BASE_L_LIST)[0]] + sorted(list(BASE_L_LIST)[1:])
|
123 |
+
INVERTED_LANGUAGES = {value: key for key, value in LANGUAGES.items()}
|
124 |
+
|
125 |
+
EXTRA_ALIGN = {
|
126 |
+
"id": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
|
127 |
+
"bn": "arijitx/wav2vec2-large-xlsr-bengali",
|
128 |
+
"mr": "sumedh/wav2vec2-large-xlsr-marathi",
|
129 |
+
"ta": "Amrrs/wav2vec2-large-xlsr-53-tamil",
|
130 |
+
"jw": "cahya/wav2vec2-large-xlsr-javanese",
|
131 |
+
"ne": "shniranjan/wav2vec2-large-xlsr-300m-nepali",
|
132 |
+
"th": "sakares/wav2vec2-large-xlsr-thai-demo",
|
133 |
+
"sv": "KBLab/wav2vec2-large-voxrex-swedish",
|
134 |
+
"am": "agkphysics/wav2vec2-large-xlsr-53-amharic",
|
135 |
+
"cy": "Srulikbdd/Wav2Vec2-large-xlsr-welsh",
|
136 |
+
"et": "anton-l/wav2vec2-large-xlsr-53-estonian",
|
137 |
+
"hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
|
138 |
+
"is": "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h",
|
139 |
+
"ka": "MehdiHosseiniMoghadam/wav2vec2-large-xlsr-53-Georgian",
|
140 |
+
"km": "vitouphy/wav2vec2-xls-r-300m-khmer",
|
141 |
+
"sk": "infinitejoy/wav2vec2-large-xls-r-300m-slovak",
|
142 |
+
"sq": "Alimzhan/wav2vec2-large-xls-r-300m-albanian-colab",
|
143 |
+
"sr": "dnikolic/wav2vec2-xlsr-530-serbian-colab",
|
144 |
+
"az": "nijatzeynalov/wav2vec2-large-mms-1b-azerbaijani-common_voice15.0",
|
145 |
+
"bg": "infinitejoy/wav2vec2-large-xls-r-300m-bulgarian",
|
146 |
+
"gl": "ifrz/wav2vec2-large-xlsr-galician",
|
147 |
+
"gu": "Harveenchadha/vakyansh-wav2vec2-gujarati-gnm-100",
|
148 |
+
"kk": "aismlv/wav2vec2-large-xlsr-kazakh",
|
149 |
+
"kn": "Harveenchadha/vakyansh-wav2vec2-kannada-knm-560",
|
150 |
+
"lt": "DeividasM/wav2vec2-large-xlsr-53-lithuanian",
|
151 |
+
"lv": "anton-l/wav2vec2-large-xlsr-53-latvian",
|
152 |
+
"mk": "", # Konstantin-Bogdanoski/wav2vec2-macedonian-base
|
153 |
+
"ml": "gvs/wav2vec2-large-xlsr-malayalam",
|
154 |
+
"ms": "", # Duy/wav2vec2_malay
|
155 |
+
"ro": "anton-l/wav2vec2-large-xlsr-53-romanian",
|
156 |
+
"si": "IAmNotAnanth/wav2vec2-large-xls-r-300m-sinhala",
|
157 |
+
"su": "cahya/wav2vec2-large-xlsr-sundanese",
|
158 |
+
"sw": "", # Lians/fine-tune-wav2vec2-large-swahili
|
159 |
+
"af": "", # ylacombe/wav2vec2-common_voice-af-demo
|
160 |
+
"bs": "",
|
161 |
+
"la": "",
|
162 |
+
"my": "",
|
163 |
+
"no": "NbAiLab/wav2vec2-xlsr-300m-norwegian",
|
164 |
+
"zh-TW": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
|
165 |
+
"as": "",
|
166 |
+
"eu": "", # cahya/wav2vec2-large-xlsr-basque # verify
|
167 |
+
"ha": "infinitejoy/wav2vec2-large-xls-r-300m-hausa",
|
168 |
+
"ht": "",
|
169 |
+
"hy": "infinitejoy/wav2vec2-large-xls-r-300m-armenian", # no (.)
|
170 |
+
"lo": "",
|
171 |
+
"mg": "",
|
172 |
+
"mn": "tugstugi/wav2vec2-large-xlsr-53-mongolian",
|
173 |
+
"mt": "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-maltese-64h",
|
174 |
+
"pa": "kingabzpro/wav2vec2-large-xlsr-53-punjabi",
|
175 |
+
"ps": "aamirhs/wav2vec2-large-xls-r-300m-pashto-colab",
|
176 |
+
"sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
|
177 |
+
"sn": "",
|
178 |
+
"so": "",
|
179 |
+
"tg": "",
|
180 |
+
"tk": "", # Ragav/wav2vec2-tk
|
181 |
+
"tt": "anton-l/wav2vec2-large-xlsr-53-tatar",
|
182 |
+
"uz": "", # Mekhriddin/wav2vec2-large-xls-r-300m-uzbek-colab
|
183 |
+
"yo": "ogbi/wav2vec2-large-mms-1b-yoruba-test",
|
184 |
+
}
|
185 |
+
|
186 |
+
|
187 |
+
def fix_code_language(translate_to, syntax="google"):
|
188 |
+
if syntax == "google":
|
189 |
+
# google-translator, gTTS
|
190 |
+
replace_lang_code = {"zh": "zh-CN", "he": "iw", "zh-cn": "zh-CN"}
|
191 |
+
elif syntax == "coqui":
|
192 |
+
# coqui-xtts
|
193 |
+
replace_lang_code = {"zh": "zh-cn", "zh-CN": "zh-cn", "zh-TW": "zh-cn"}
|
194 |
+
|
195 |
+
new_code_lang = replace_lang_code.get(translate_to, translate_to)
|
196 |
+
logger.debug(f"Fix code {translate_to} -> {new_code_lang}")
|
197 |
+
return new_code_lang
|
198 |
+
|
199 |
+
|
200 |
+
BARK_VOICES_LIST = {
|
201 |
+
"de_speaker_0-Male BARK": "v2/de_speaker_0",
|
202 |
+
"de_speaker_1-Male BARK": "v2/de_speaker_1",
|
203 |
+
"de_speaker_2-Male BARK": "v2/de_speaker_2",
|
204 |
+
"de_speaker_3-Female BARK": "v2/de_speaker_3",
|
205 |
+
"de_speaker_4-Male BARK": "v2/de_speaker_4",
|
206 |
+
"de_speaker_5-Male BARK": "v2/de_speaker_5",
|
207 |
+
"de_speaker_6-Male BARK": "v2/de_speaker_6",
|
208 |
+
"de_speaker_7-Male BARK": "v2/de_speaker_7",
|
209 |
+
"de_speaker_8-Female BARK": "v2/de_speaker_8",
|
210 |
+
"de_speaker_9-Male BARK": "v2/de_speaker_9",
|
211 |
+
"en_speaker_0-Male BARK": "v2/en_speaker_0",
|
212 |
+
"en_speaker_1-Male BARK": "v2/en_speaker_1",
|
213 |
+
"en_speaker_2-Male BARK": "v2/en_speaker_2",
|
214 |
+
"en_speaker_3-Male BARK": "v2/en_speaker_3",
|
215 |
+
"en_speaker_4-Male BARK": "v2/en_speaker_4",
|
216 |
+
"en_speaker_5-Male BARK": "v2/en_speaker_5",
|
217 |
+
"en_speaker_6-Male BARK": "v2/en_speaker_6",
|
218 |
+
"en_speaker_7-Male BARK": "v2/en_speaker_7",
|
219 |
+
"en_speaker_8-Male BARK": "v2/en_speaker_8",
|
220 |
+
"en_speaker_9-Female BARK": "v2/en_speaker_9",
|
221 |
+
"es_speaker_0-Male BARK": "v2/es_speaker_0",
|
222 |
+
"es_speaker_1-Male BARK": "v2/es_speaker_1",
|
223 |
+
"es_speaker_2-Male BARK": "v2/es_speaker_2",
|
224 |
+
"es_speaker_3-Male BARK": "v2/es_speaker_3",
|
225 |
+
"es_speaker_4-Male BARK": "v2/es_speaker_4",
|
226 |
+
"es_speaker_5-Male BARK": "v2/es_speaker_5",
|
227 |
+
"es_speaker_6-Male BARK": "v2/es_speaker_6",
|
228 |
+
"es_speaker_7-Male BARK": "v2/es_speaker_7",
|
229 |
+
"es_speaker_8-Female BARK": "v2/es_speaker_8",
|
230 |
+
"es_speaker_9-Female BARK": "v2/es_speaker_9",
|
231 |
+
"fr_speaker_0-Male BARK": "v2/fr_speaker_0",
|
232 |
+
"fr_speaker_1-Female BARK": "v2/fr_speaker_1",
|
233 |
+
"fr_speaker_2-Female BARK": "v2/fr_speaker_2",
|
234 |
+
"fr_speaker_3-Male BARK": "v2/fr_speaker_3",
|
235 |
+
"fr_speaker_4-Male BARK": "v2/fr_speaker_4",
|
236 |
+
"fr_speaker_5-Female BARK": "v2/fr_speaker_5",
|
237 |
+
"fr_speaker_6-Male BARK": "v2/fr_speaker_6",
|
238 |
+
"fr_speaker_7-Male BARK": "v2/fr_speaker_7",
|
239 |
+
"fr_speaker_8-Male BARK": "v2/fr_speaker_8",
|
240 |
+
"fr_speaker_9-Male BARK": "v2/fr_speaker_9",
|
241 |
+
"hi_speaker_0-Female BARK": "v2/hi_speaker_0",
|
242 |
+
"hi_speaker_1-Female BARK": "v2/hi_speaker_1",
|
243 |
+
"hi_speaker_2-Male BARK": "v2/hi_speaker_2",
|
244 |
+
"hi_speaker_3-Female BARK": "v2/hi_speaker_3",
|
245 |
+
"hi_speaker_4-Female BARK": "v2/hi_speaker_4",
|
246 |
+
"hi_speaker_5-Male BARK": "v2/hi_speaker_5",
|
247 |
+
"hi_speaker_6-Male BARK": "v2/hi_speaker_6",
|
248 |
+
"hi_speaker_7-Male BARK": "v2/hi_speaker_7",
|
249 |
+
"hi_speaker_8-Male BARK": "v2/hi_speaker_8",
|
250 |
+
"hi_speaker_9-Female BARK": "v2/hi_speaker_9",
|
251 |
+
"it_speaker_0-Male BARK": "v2/it_speaker_0",
|
252 |
+
"it_speaker_1-Male BARK": "v2/it_speaker_1",
|
253 |
+
"it_speaker_2-Female BARK": "v2/it_speaker_2",
|
254 |
+
"it_speaker_3-Male BARK": "v2/it_speaker_3",
|
255 |
+
"it_speaker_4-Male BARK": "v2/it_speaker_4",
|
256 |
+
"it_speaker_5-Male BARK": "v2/it_speaker_5",
|
257 |
+
"it_speaker_6-Male BARK": "v2/it_speaker_6",
|
258 |
+
"it_speaker_7-Female BARK": "v2/it_speaker_7",
|
259 |
+
"it_speaker_8-Male BARK": "v2/it_speaker_8",
|
260 |
+
"it_speaker_9-Female BARK": "v2/it_speaker_9",
|
261 |
+
"ja_speaker_0-Female BARK": "v2/ja_speaker_0",
|
262 |
+
"ja_speaker_1-Female BARK": "v2/ja_speaker_1",
|
263 |
+
"ja_speaker_2-Male BARK": "v2/ja_speaker_2",
|
264 |
+
"ja_speaker_3-Female BARK": "v2/ja_speaker_3",
|
265 |
+
"ja_speaker_4-Female BARK": "v2/ja_speaker_4",
|
266 |
+
"ja_speaker_5-Female BARK": "v2/ja_speaker_5",
|
267 |
+
"ja_speaker_6-Male BARK": "v2/ja_speaker_6",
|
268 |
+
"ja_speaker_7-Female BARK": "v2/ja_speaker_7",
|
269 |
+
"ja_speaker_8-Female BARK": "v2/ja_speaker_8",
|
270 |
+
"ja_speaker_9-Female BARK": "v2/ja_speaker_9",
|
271 |
+
"ko_speaker_0-Female BARK": "v2/ko_speaker_0",
|
272 |
+
"ko_speaker_1-Male BARK": "v2/ko_speaker_1",
|
273 |
+
"ko_speaker_2-Male BARK": "v2/ko_speaker_2",
|
274 |
+
"ko_speaker_3-Male BARK": "v2/ko_speaker_3",
|
275 |
+
"ko_speaker_4-Male BARK": "v2/ko_speaker_4",
|
276 |
+
"ko_speaker_5-Male BARK": "v2/ko_speaker_5",
|
277 |
+
"ko_speaker_6-Male BARK": "v2/ko_speaker_6",
|
278 |
+
"ko_speaker_7-Male BARK": "v2/ko_speaker_7",
|
279 |
+
"ko_speaker_8-Male BARK": "v2/ko_speaker_8",
|
280 |
+
"ko_speaker_9-Male BARK": "v2/ko_speaker_9",
|
281 |
+
"pl_speaker_0-Male BARK": "v2/pl_speaker_0",
|
282 |
+
"pl_speaker_1-Male BARK": "v2/pl_speaker_1",
|
283 |
+
"pl_speaker_2-Male BARK": "v2/pl_speaker_2",
|
284 |
+
"pl_speaker_3-Male BARK": "v2/pl_speaker_3",
|
285 |
+
"pl_speaker_4-Female BARK": "v2/pl_speaker_4",
|
286 |
+
"pl_speaker_5-Male BARK": "v2/pl_speaker_5",
|
287 |
+
"pl_speaker_6-Female BARK": "v2/pl_speaker_6",
|
288 |
+
"pl_speaker_7-Male BARK": "v2/pl_speaker_7",
|
289 |
+
"pl_speaker_8-Male BARK": "v2/pl_speaker_8",
|
290 |
+
"pl_speaker_9-Female BARK": "v2/pl_speaker_9",
|
291 |
+
"pt_speaker_0-Male BARK": "v2/pt_speaker_0",
|
292 |
+
"pt_speaker_1-Male BARK": "v2/pt_speaker_1",
|
293 |
+
"pt_speaker_2-Male BARK": "v2/pt_speaker_2",
|
294 |
+
"pt_speaker_3-Male BARK": "v2/pt_speaker_3",
|
295 |
+
"pt_speaker_4-Male BARK": "v2/pt_speaker_4",
|
296 |
+
"pt_speaker_5-Male BARK": "v2/pt_speaker_5",
|
297 |
+
"pt_speaker_6-Male BARK": "v2/pt_speaker_6",
|
298 |
+
"pt_speaker_7-Male BARK": "v2/pt_speaker_7",
|
299 |
+
"pt_speaker_8-Male BARK": "v2/pt_speaker_8",
|
300 |
+
"pt_speaker_9-Male BARK": "v2/pt_speaker_9",
|
301 |
+
"ru_speaker_0-Male BARK": "v2/ru_speaker_0",
|
302 |
+
"ru_speaker_1-Male BARK": "v2/ru_speaker_1",
|
303 |
+
"ru_speaker_2-Male BARK": "v2/ru_speaker_2",
|
304 |
+
"ru_speaker_3-Male BARK": "v2/ru_speaker_3",
|
305 |
+
"ru_speaker_4-Male BARK": "v2/ru_speaker_4",
|
306 |
+
"ru_speaker_5-Female BARK": "v2/ru_speaker_5",
|
307 |
+
"ru_speaker_6-Female BARK": "v2/ru_speaker_6",
|
308 |
+
"ru_speaker_7-Male BARK": "v2/ru_speaker_7",
|
309 |
+
"ru_speaker_8-Male BARK": "v2/ru_speaker_8",
|
310 |
+
"ru_speaker_9-Female BARK": "v2/ru_speaker_9",
|
311 |
+
"tr_speaker_0-Male BARK": "v2/tr_speaker_0",
|
312 |
+
"tr_speaker_1-Male BARK": "v2/tr_speaker_1",
|
313 |
+
"tr_speaker_2-Male BARK": "v2/tr_speaker_2",
|
314 |
+
"tr_speaker_3-Male BARK": "v2/tr_speaker_3",
|
315 |
+
"tr_speaker_4-Female BARK": "v2/tr_speaker_4",
|
316 |
+
"tr_speaker_5-Female BARK": "v2/tr_speaker_5",
|
317 |
+
"tr_speaker_6-Male BARK": "v2/tr_speaker_6",
|
318 |
+
"tr_speaker_7-Male BARK": "v2/tr_speaker_7",
|
319 |
+
"tr_speaker_8-Male BARK": "v2/tr_speaker_8",
|
320 |
+
"tr_speaker_9-Male BARK": "v2/tr_speaker_9",
|
321 |
+
"zh_speaker_0-Male BARK": "v2/zh_speaker_0",
|
322 |
+
"zh_speaker_1-Male BARK": "v2/zh_speaker_1",
|
323 |
+
"zh_speaker_2-Male BARK": "v2/zh_speaker_2",
|
324 |
+
"zh_speaker_3-Male BARK": "v2/zh_speaker_3",
|
325 |
+
"zh_speaker_4-Female BARK": "v2/zh_speaker_4",
|
326 |
+
"zh_speaker_5-Male BARK": "v2/zh_speaker_5",
|
327 |
+
"zh_speaker_6-Female BARK": "v2/zh_speaker_6",
|
328 |
+
"zh_speaker_7-Female BARK": "v2/zh_speaker_7",
|
329 |
+
"zh_speaker_8-Male BARK": "v2/zh_speaker_8",
|
330 |
+
"zh_speaker_9-Female BARK": "v2/zh_speaker_9",
|
331 |
+
}
|
332 |
+
|
333 |
+
VITS_VOICES_LIST = {
|
334 |
+
"ar-facebook-mms VITS": "facebook/mms-tts-ara",
|
335 |
+
# 'zh-facebook-mms VITS': 'facebook/mms-tts-cmn',
|
336 |
+
"zh_Hakka-facebook-mms VITS": "facebook/mms-tts-hak",
|
337 |
+
"zh_MinNan-facebook-mms VITS": "facebook/mms-tts-nan",
|
338 |
+
# 'cs-facebook-mms VITS': 'facebook/mms-tts-ces',
|
339 |
+
# 'da-facebook-mms VITS': 'facebook/mms-tts-dan',
|
340 |
+
"nl-facebook-mms VITS": "facebook/mms-tts-nld",
|
341 |
+
"en-facebook-mms VITS": "facebook/mms-tts-eng",
|
342 |
+
"fi-facebook-mms VITS": "facebook/mms-tts-fin",
|
343 |
+
"fr-facebook-mms VITS": "facebook/mms-tts-fra",
|
344 |
+
"de-facebook-mms VITS": "facebook/mms-tts-deu",
|
345 |
+
"el-facebook-mms VITS": "facebook/mms-tts-ell",
|
346 |
+
"el_Ancient-facebook-mms VITS": "facebook/mms-tts-grc",
|
347 |
+
"he-facebook-mms VITS": "facebook/mms-tts-heb",
|
348 |
+
"hu-facebook-mms VITS": "facebook/mms-tts-hun",
|
349 |
+
# 'it-facebook-mms VITS': 'facebook/mms-tts-ita',
|
350 |
+
# 'ja-facebook-mms VITS': 'facebook/mms-tts-jpn',
|
351 |
+
"ko-facebook-mms VITS": "facebook/mms-tts-kor",
|
352 |
+
"fa-facebook-mms VITS": "facebook/mms-tts-fas",
|
353 |
+
"pl-facebook-mms VITS": "facebook/mms-tts-pol",
|
354 |
+
"pt-facebook-mms VITS": "facebook/mms-tts-por",
|
355 |
+
"ru-facebook-mms VITS": "facebook/mms-tts-rus",
|
356 |
+
"es-facebook-mms VITS": "facebook/mms-tts-spa",
|
357 |
+
"tr-facebook-mms VITS": "facebook/mms-tts-tur",
|
358 |
+
"uk-facebook-mms VITS": "facebook/mms-tts-ukr",
|
359 |
+
"ur_arabic-facebook-mms VITS": "facebook/mms-tts-urd-script_arabic",
|
360 |
+
"ur_devanagari-facebook-mms VITS": "facebook/mms-tts-urd-script_devanagari",
|
361 |
+
"ur_latin-facebook-mms VITS": "facebook/mms-tts-urd-script_latin",
|
362 |
+
"vi-facebook-mms VITS": "facebook/mms-tts-vie",
|
363 |
+
"hi-facebook-mms VITS": "facebook/mms-tts-hin",
|
364 |
+
"hi_Fiji-facebook-mms VITS": "facebook/mms-tts-hif",
|
365 |
+
"id-facebook-mms VITS": "facebook/mms-tts-ind",
|
366 |
+
"bn-facebook-mms VITS": "facebook/mms-tts-ben",
|
367 |
+
"te-facebook-mms VITS": "facebook/mms-tts-tel",
|
368 |
+
"mr-facebook-mms VITS": "facebook/mms-tts-mar",
|
369 |
+
"ta-facebook-mms VITS": "facebook/mms-tts-tam",
|
370 |
+
"jw-facebook-mms VITS": "facebook/mms-tts-jav",
|
371 |
+
"jw_Suriname-facebook-mms VITS": "facebook/mms-tts-jvn",
|
372 |
+
"ca-facebook-mms VITS": "facebook/mms-tts-cat",
|
373 |
+
"ne-facebook-mms VITS": "facebook/mms-tts-nep",
|
374 |
+
"th-facebook-mms VITS": "facebook/mms-tts-tha",
|
375 |
+
"th_Northern-facebook-mms VITS": "facebook/mms-tts-nod",
|
376 |
+
"sv-facebook-mms VITS": "facebook/mms-tts-swe",
|
377 |
+
"am-facebook-mms VITS": "facebook/mms-tts-amh",
|
378 |
+
"cy-facebook-mms VITS": "facebook/mms-tts-cym",
|
379 |
+
# "et-facebook-mms VITS": "facebook/mms-tts-est",
|
380 |
+
# "ht-facebook-mms VITS": "facebook/mms-tts-hrv",
|
381 |
+
"is-facebook-mms VITS": "facebook/mms-tts-isl",
|
382 |
+
"km-facebook-mms VITS": "facebook/mms-tts-khm",
|
383 |
+
"km_Northern-facebook-mms VITS": "facebook/mms-tts-kxm",
|
384 |
+
# "sk-facebook-mms VITS": "facebook/mms-tts-slk",
|
385 |
+
"sq_Northern-facebook-mms VITS": "facebook/mms-tts-sqi",
|
386 |
+
"az_South-facebook-mms VITS": "facebook/mms-tts-azb",
|
387 |
+
"az_North_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-azj-script_cyrillic",
|
388 |
+
"az_North_script_latin-facebook-mms VITS": "facebook/mms-tts-azj-script_latin",
|
389 |
+
"bg-facebook-mms VITS": "facebook/mms-tts-bul",
|
390 |
+
# "gl-facebook-mms VITS": "facebook/mms-tts-glg",
|
391 |
+
"gu-facebook-mms VITS": "facebook/mms-tts-guj",
|
392 |
+
"kk-facebook-mms VITS": "facebook/mms-tts-kaz",
|
393 |
+
"kn-facebook-mms VITS": "facebook/mms-tts-kan",
|
394 |
+
# "lt-facebook-mms VITS": "facebook/mms-tts-lit",
|
395 |
+
"lv-facebook-mms VITS": "facebook/mms-tts-lav",
|
396 |
+
# "mk-facebook-mms VITS": "facebook/mms-tts-mkd",
|
397 |
+
"ml-facebook-mms VITS": "facebook/mms-tts-mal",
|
398 |
+
"ms-facebook-mms VITS": "facebook/mms-tts-zlm",
|
399 |
+
"ms_Central-facebook-mms VITS": "facebook/mms-tts-pse",
|
400 |
+
"ms_Manado-facebook-mms VITS": "facebook/mms-tts-xmm",
|
401 |
+
"ro-facebook-mms VITS": "facebook/mms-tts-ron",
|
402 |
+
# "si-facebook-mms VITS": "facebook/mms-tts-sin",
|
403 |
+
"sw-facebook-mms VITS": "facebook/mms-tts-swh",
|
404 |
+
# "af-facebook-mms VITS": "facebook/mms-tts-afr",
|
405 |
+
# "bs-facebook-mms VITS": "facebook/mms-tts-bos",
|
406 |
+
"la-facebook-mms VITS": "facebook/mms-tts-lat",
|
407 |
+
"my-facebook-mms VITS": "facebook/mms-tts-mya",
|
408 |
+
# "no_Bokmål-facebook-mms VITS": "thomasht86/mms-tts-nob", # verify
|
409 |
+
"as-facebook-mms VITS": "facebook/mms-tts-asm",
|
410 |
+
"as_Nagamese-facebook-mms VITS": "facebook/mms-tts-nag",
|
411 |
+
"eu-facebook-mms VITS": "facebook/mms-tts-eus",
|
412 |
+
"ha-facebook-mms VITS": "facebook/mms-tts-hau",
|
413 |
+
"ht-facebook-mms VITS": "facebook/mms-tts-hat",
|
414 |
+
"hy_Western-facebook-mms VITS": "facebook/mms-tts-hyw",
|
415 |
+
"lo-facebook-mms VITS": "facebook/mms-tts-lao",
|
416 |
+
"mg-facebook-mms VITS": "facebook/mms-tts-mlg",
|
417 |
+
"mn-facebook-mms VITS": "facebook/mms-tts-mon",
|
418 |
+
# "mt-facebook-mms VITS": "facebook/mms-tts-mlt",
|
419 |
+
"pa_Eastern-facebook-mms VITS": "facebook/mms-tts-pan",
|
420 |
+
# "pa_Western-facebook-mms VITS": "facebook/mms-tts-pnb",
|
421 |
+
# "ps-facebook-mms VITS": "facebook/mms-tts-pus",
|
422 |
+
# "sl-facebook-mms VITS": "facebook/mms-tts-slv",
|
423 |
+
"sn-facebook-mms VITS": "facebook/mms-tts-sna",
|
424 |
+
"so-facebook-mms VITS": "facebook/mms-tts-son",
|
425 |
+
"tg-facebook-mms VITS": "facebook/mms-tts-tgk",
|
426 |
+
"tk_script_arabic-facebook-mms VITS": "facebook/mms-tts-tuk-script_arabic",
|
427 |
+
"tk_script_latin-facebook-mms VITS": "facebook/mms-tts-tuk-script_latin",
|
428 |
+
"tt-facebook-mms VITS": "facebook/mms-tts-tat",
|
429 |
+
"tt_Crimean-facebook-mms VITS": "facebook/mms-tts-crh",
|
430 |
+
"uz_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-uzb-script_cyrillic",
|
431 |
+
"yo-facebook-mms VITS": "facebook/mms-tts-yor",
|
432 |
+
"ay-facebook-mms VITS": "facebook/mms-tts-ayr",
|
433 |
+
"bm-facebook-mms VITS": "facebook/mms-tts-bam",
|
434 |
+
"ceb-facebook-mms VITS": "facebook/mms-tts-ceb",
|
435 |
+
"ny-facebook-mms VITS": "facebook/mms-tts-nya",
|
436 |
+
"dv-facebook-mms VITS": "facebook/mms-tts-div",
|
437 |
+
"doi-facebook-mms VITS": "facebook/mms-tts-dgo",
|
438 |
+
"ee-facebook-mms VITS": "facebook/mms-tts-ewe",
|
439 |
+
"gn-facebook-mms VITS": "facebook/mms-tts-grn",
|
440 |
+
"ilo-facebook-mms VITS": "facebook/mms-tts-ilo",
|
441 |
+
"rw-facebook-mms VITS": "facebook/mms-tts-kin",
|
442 |
+
"kri-facebook-mms VITS": "facebook/mms-tts-kri",
|
443 |
+
"ku_script_arabic-facebook-mms VITS": "facebook/mms-tts-kmr-script_arabic",
|
444 |
+
"ku_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-kmr-script_cyrillic",
|
445 |
+
"ku_script_latin-facebook-mms VITS": "facebook/mms-tts-kmr-script_latin",
|
446 |
+
"ckb-facebook-mms VITS": "razhan/mms-tts-ckb", # Verify w
|
447 |
+
"ky-facebook-mms VITS": "facebook/mms-tts-kir",
|
448 |
+
"lg-facebook-mms VITS": "facebook/mms-tts-lug",
|
449 |
+
"mai-facebook-mms VITS": "facebook/mms-tts-mai",
|
450 |
+
"or-facebook-mms VITS": "facebook/mms-tts-ory",
|
451 |
+
"om-facebook-mms VITS": "facebook/mms-tts-orm",
|
452 |
+
"qu_Huallaga-facebook-mms VITS": "facebook/mms-tts-qub",
|
453 |
+
"qu_Lambayeque-facebook-mms VITS": "facebook/mms-tts-quf",
|
454 |
+
"qu_South_Bolivian-facebook-mms VITS": "facebook/mms-tts-quh",
|
455 |
+
"qu_North_Bolivian-facebook-mms VITS": "facebook/mms-tts-qul",
|
456 |
+
"qu_Tena_Lowland-facebook-mms VITS": "facebook/mms-tts-quw",
|
457 |
+
"qu_Ayacucho-facebook-mms VITS": "facebook/mms-tts-quy",
|
458 |
+
"qu_Cusco-facebook-mms VITS": "facebook/mms-tts-quz",
|
459 |
+
"qu_Cajamarca-facebook-mms VITS": "facebook/mms-tts-qvc",
|
460 |
+
"qu_Eastern_Apurímac-facebook-mms VITS": "facebook/mms-tts-qve",
|
461 |
+
"qu_Huamalíes_Dos_de_Mayo_Huánuco-facebook-mms VITS": "facebook/mms-tts-qvh",
|
462 |
+
"qu_Margos_Yarowilca_Lauricocha-facebook-mms VITS": "facebook/mms-tts-qvm",
|
463 |
+
"qu_North_Junín-facebook-mms VITS": "facebook/mms-tts-qvn",
|
464 |
+
"qu_Napo-facebook-mms VITS": "facebook/mms-tts-qvo",
|
465 |
+
"qu_San_Martín-facebook-mms VITS": "facebook/mms-tts-qvs",
|
466 |
+
"qu_Huaylla_Wanca-facebook-mms VITS": "facebook/mms-tts-qvw",
|
467 |
+
"qu_Northern_Pastaza-facebook-mms VITS": "facebook/mms-tts-qvz",
|
468 |
+
"qu_Huaylas_Ancash-facebook-mms VITS": "facebook/mms-tts-qwh",
|
469 |
+
"qu_Panao-facebook-mms VITS": "facebook/mms-tts-qxh",
|
470 |
+
"qu_Salasaca_Highland-facebook-mms VITS": "facebook/mms-tts-qxl",
|
471 |
+
"qu_Northern_Conchucos_Ancash-facebook-mms VITS": "facebook/mms-tts-qxn",
|
472 |
+
"qu_Southern_Conchucos-facebook-mms VITS": "facebook/mms-tts-qxo",
|
473 |
+
"qu_Cañar_Highland-facebook-mms VITS": "facebook/mms-tts-qxr",
|
474 |
+
"sm-facebook-mms VITS": "facebook/mms-tts-smo",
|
475 |
+
"ti-facebook-mms VITS": "facebook/mms-tts-tir",
|
476 |
+
"ts-facebook-mms VITS": "facebook/mms-tts-tso",
|
477 |
+
"ak-facebook-mms VITS": "facebook/mms-tts-aka",
|
478 |
+
"ug_script_arabic-facebook-mms VITS": "facebook/mms-tts-uig-script_arabic",
|
479 |
+
"ug_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-uig-script_cyrillic",
|
480 |
+
}
|
481 |
+
|
482 |
+
OPENAI_TTS_CODES = [
|
483 |
+
"af", "ar", "hy", "az", "be", "bs", "bg", "ca", "zh", "hr", "cs", "da",
|
484 |
+
"nl", "en", "et", "fi", "fr", "gl", "de", "el", "he", "hi", "hu", "is",
|
485 |
+
"id", "it", "ja", "kn", "kk", "ko", "lv", "lt", "mk", "ms", "mr", "mi",
|
486 |
+
"ne", "no", "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw",
|
487 |
+
"sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy", "zh-TW"
|
488 |
+
]
|
489 |
+
|
490 |
+
OPENAI_TTS_MODELS = [
|
491 |
+
">alloy OpenAI-TTS",
|
492 |
+
">echo OpenAI-TTS",
|
493 |
+
">fable OpenAI-TTS",
|
494 |
+
">onyx OpenAI-TTS",
|
495 |
+
">nova OpenAI-TTS",
|
496 |
+
">shimmer OpenAI-TTS",
|
497 |
+
">alloy HD OpenAI-TTS",
|
498 |
+
">echo HD OpenAI-TTS",
|
499 |
+
">fable HD OpenAI-TTS",
|
500 |
+
">onyx HD OpenAI-TTS",
|
501 |
+
">nova HD OpenAI-TTS",
|
502 |
+
">shimmer HD OpenAI-TTS"
|
503 |
+
]
|
504 |
+
|
505 |
+
LANGUAGE_CODE_IN_THREE_LETTERS = {
|
506 |
+
"Automatic detection": "aut",
|
507 |
+
"ar": "ara",
|
508 |
+
"zh": "chi",
|
509 |
+
"cs": "cze",
|
510 |
+
"da": "dan",
|
511 |
+
"nl": "dut",
|
512 |
+
"en": "eng",
|
513 |
+
"fi": "fin",
|
514 |
+
"fr": "fre",
|
515 |
+
"de": "ger",
|
516 |
+
"el": "gre",
|
517 |
+
"he": "heb",
|
518 |
+
"hu": "hun",
|
519 |
+
"it": "ita",
|
520 |
+
"ja": "jpn",
|
521 |
+
"ko": "kor",
|
522 |
+
"fa": "per",
|
523 |
+
"pl": "pol",
|
524 |
+
"pt": "por",
|
525 |
+
"ru": "rus",
|
526 |
+
"es": "spa",
|
527 |
+
"tr": "tur",
|
528 |
+
"uk": "ukr",
|
529 |
+
"ur": "urd",
|
530 |
+
"vi": "vie",
|
531 |
+
"hi": "hin",
|
532 |
+
"id": "ind",
|
533 |
+
"bn": "ben",
|
534 |
+
"te": "tel",
|
535 |
+
"mr": "mar",
|
536 |
+
"ta": "tam",
|
537 |
+
"jw": "jav",
|
538 |
+
"ca": "cat",
|
539 |
+
"ne": "nep",
|
540 |
+
"th": "tha",
|
541 |
+
"sv": "swe",
|
542 |
+
"am": "amh",
|
543 |
+
"cy": "cym",
|
544 |
+
"et": "est",
|
545 |
+
"hr": "hrv",
|
546 |
+
"is": "isl",
|
547 |
+
"km": "khm",
|
548 |
+
"sk": "slk",
|
549 |
+
"sq": "sqi",
|
550 |
+
"sr": "srp",
|
551 |
+
}
|
soni_translate/languages_gui.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
soni_translate/logging_setup.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import sys
|
3 |
+
import warnings
|
4 |
+
import os
|
5 |
+
|
6 |
+
|
7 |
+
def configure_logging_libs(debug=False):
|
8 |
+
warnings.filterwarnings(
|
9 |
+
action="ignore", category=UserWarning, module="pyannote"
|
10 |
+
)
|
11 |
+
modules = [
|
12 |
+
"numba", "httpx", "markdown_it", "speechbrain", "fairseq", "pyannote",
|
13 |
+
"faiss",
|
14 |
+
"pytorch_lightning.utilities.migration.utils",
|
15 |
+
"pytorch_lightning.utilities.migration",
|
16 |
+
"pytorch_lightning",
|
17 |
+
"lightning",
|
18 |
+
"lightning.pytorch.utilities.migration.utils",
|
19 |
+
]
|
20 |
+
try:
|
21 |
+
for module in modules:
|
22 |
+
logging.getLogger(module).setLevel(logging.WARNING)
|
23 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3" if not debug else "1"
|
24 |
+
|
25 |
+
# fix verbose pyannote audio
|
26 |
+
def fix_verbose_pyannote(*args, what=""):
|
27 |
+
pass
|
28 |
+
import pyannote.audio.core.model # noqa
|
29 |
+
pyannote.audio.core.model.check_version = fix_verbose_pyannote
|
30 |
+
except Exception as error:
|
31 |
+
logger.error(str(error))
|
32 |
+
|
33 |
+
|
34 |
+
def setup_logger(name_log):
|
35 |
+
logger = logging.getLogger(name_log)
|
36 |
+
logger.setLevel(logging.INFO)
|
37 |
+
|
38 |
+
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
|
39 |
+
_default_handler.flush = sys.stderr.flush
|
40 |
+
logger.addHandler(_default_handler)
|
41 |
+
|
42 |
+
logger.propagate = False
|
43 |
+
|
44 |
+
handlers = logger.handlers
|
45 |
+
|
46 |
+
for handler in handlers:
|
47 |
+
formatter = logging.Formatter("[%(levelname)s] >> %(message)s")
|
48 |
+
handler.setFormatter(formatter)
|
49 |
+
|
50 |
+
# logger.handlers
|
51 |
+
|
52 |
+
return logger
|
53 |
+
|
54 |
+
|
55 |
+
logger = setup_logger("sonitranslate")
|
56 |
+
logger.setLevel(logging.INFO)
|
57 |
+
|
58 |
+
|
59 |
+
def set_logging_level(verbosity_level):
|
60 |
+
logging_level_mapping = {
|
61 |
+
"debug": logging.DEBUG,
|
62 |
+
"info": logging.INFO,
|
63 |
+
"warning": logging.WARNING,
|
64 |
+
"error": logging.ERROR,
|
65 |
+
"critical": logging.CRITICAL,
|
66 |
+
}
|
67 |
+
|
68 |
+
logger.setLevel(logging_level_mapping.get(verbosity_level, logging.INFO))
|
soni_translate/mdx_net.py
ADDED
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import hashlib
|
3 |
+
import os
|
4 |
+
import queue
|
5 |
+
import threading
|
6 |
+
import json
|
7 |
+
import shlex
|
8 |
+
import sys
|
9 |
+
import subprocess
|
10 |
+
import librosa
|
11 |
+
import numpy as np
|
12 |
+
import soundfile as sf
|
13 |
+
import torch
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
try:
|
17 |
+
from .utils import (
|
18 |
+
remove_directory_contents,
|
19 |
+
create_directories,
|
20 |
+
)
|
21 |
+
except: # noqa
|
22 |
+
from utils import (
|
23 |
+
remove_directory_contents,
|
24 |
+
create_directories,
|
25 |
+
)
|
26 |
+
from .logging_setup import logger
|
27 |
+
|
28 |
+
try:
|
29 |
+
import onnxruntime as ort
|
30 |
+
except Exception as error:
|
31 |
+
logger.error(str(error))
|
32 |
+
# import warnings
|
33 |
+
# warnings.filterwarnings("ignore")
|
34 |
+
|
35 |
+
stem_naming = {
|
36 |
+
"Vocals": "Instrumental",
|
37 |
+
"Other": "Instruments",
|
38 |
+
"Instrumental": "Vocals",
|
39 |
+
"Drums": "Drumless",
|
40 |
+
"Bass": "Bassless",
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
class MDXModel:
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
device,
|
48 |
+
dim_f,
|
49 |
+
dim_t,
|
50 |
+
n_fft,
|
51 |
+
hop=1024,
|
52 |
+
stem_name=None,
|
53 |
+
compensation=1.000,
|
54 |
+
):
|
55 |
+
self.dim_f = dim_f
|
56 |
+
self.dim_t = dim_t
|
57 |
+
self.dim_c = 4
|
58 |
+
self.n_fft = n_fft
|
59 |
+
self.hop = hop
|
60 |
+
self.stem_name = stem_name
|
61 |
+
self.compensation = compensation
|
62 |
+
|
63 |
+
self.n_bins = self.n_fft // 2 + 1
|
64 |
+
self.chunk_size = hop * (self.dim_t - 1)
|
65 |
+
self.window = torch.hann_window(
|
66 |
+
window_length=self.n_fft, periodic=True
|
67 |
+
).to(device)
|
68 |
+
|
69 |
+
out_c = self.dim_c
|
70 |
+
|
71 |
+
self.freq_pad = torch.zeros(
|
72 |
+
[1, out_c, self.n_bins - self.dim_f, self.dim_t]
|
73 |
+
).to(device)
|
74 |
+
|
75 |
+
def stft(self, x):
|
76 |
+
x = x.reshape([-1, self.chunk_size])
|
77 |
+
x = torch.stft(
|
78 |
+
x,
|
79 |
+
n_fft=self.n_fft,
|
80 |
+
hop_length=self.hop,
|
81 |
+
window=self.window,
|
82 |
+
center=True,
|
83 |
+
return_complex=True,
|
84 |
+
)
|
85 |
+
x = torch.view_as_real(x)
|
86 |
+
x = x.permute([0, 3, 1, 2])
|
87 |
+
x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
|
88 |
+
[-1, 4, self.n_bins, self.dim_t]
|
89 |
+
)
|
90 |
+
return x[:, :, : self.dim_f]
|
91 |
+
|
92 |
+
def istft(self, x, freq_pad=None):
|
93 |
+
freq_pad = (
|
94 |
+
self.freq_pad.repeat([x.shape[0], 1, 1, 1])
|
95 |
+
if freq_pad is None
|
96 |
+
else freq_pad
|
97 |
+
)
|
98 |
+
x = torch.cat([x, freq_pad], -2)
|
99 |
+
# c = 4*2 if self.target_name=='*' else 2
|
100 |
+
x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
|
101 |
+
[-1, 2, self.n_bins, self.dim_t]
|
102 |
+
)
|
103 |
+
x = x.permute([0, 2, 3, 1])
|
104 |
+
x = x.contiguous()
|
105 |
+
x = torch.view_as_complex(x)
|
106 |
+
x = torch.istft(
|
107 |
+
x,
|
108 |
+
n_fft=self.n_fft,
|
109 |
+
hop_length=self.hop,
|
110 |
+
window=self.window,
|
111 |
+
center=True,
|
112 |
+
)
|
113 |
+
return x.reshape([-1, 2, self.chunk_size])
|
114 |
+
|
115 |
+
|
116 |
+
class MDX:
|
117 |
+
DEFAULT_SR = 44100
|
118 |
+
# Unit: seconds
|
119 |
+
DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
|
120 |
+
DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR
|
121 |
+
|
122 |
+
def __init__(
|
123 |
+
self, model_path: str, params: MDXModel, processor=0
|
124 |
+
):
|
125 |
+
# Set the device and the provider (CPU or CUDA)
|
126 |
+
self.device = (
|
127 |
+
torch.device(f"cuda:{processor}")
|
128 |
+
if processor >= 0
|
129 |
+
else torch.device("cpu")
|
130 |
+
)
|
131 |
+
self.provider = (
|
132 |
+
["CUDAExecutionProvider"]
|
133 |
+
if processor >= 0
|
134 |
+
else ["CPUExecutionProvider"]
|
135 |
+
)
|
136 |
+
|
137 |
+
self.model = params
|
138 |
+
|
139 |
+
# Load the ONNX model using ONNX Runtime
|
140 |
+
self.ort = ort.InferenceSession(model_path, providers=self.provider)
|
141 |
+
# Preload the model for faster performance
|
142 |
+
self.ort.run(
|
143 |
+
None,
|
144 |
+
{"input": torch.rand(1, 4, params.dim_f, params.dim_t).numpy()},
|
145 |
+
)
|
146 |
+
self.process = lambda spec: self.ort.run(
|
147 |
+
None, {"input": spec.cpu().numpy()}
|
148 |
+
)[0]
|
149 |
+
|
150 |
+
self.prog = None
|
151 |
+
|
152 |
+
@staticmethod
|
153 |
+
def get_hash(model_path):
|
154 |
+
try:
|
155 |
+
with open(model_path, "rb") as f:
|
156 |
+
f.seek(-10000 * 1024, 2)
|
157 |
+
model_hash = hashlib.md5(f.read()).hexdigest()
|
158 |
+
except: # noqa
|
159 |
+
model_hash = hashlib.md5(open(model_path, "rb").read()).hexdigest()
|
160 |
+
|
161 |
+
return model_hash
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def segment(
|
165 |
+
wave,
|
166 |
+
combine=True,
|
167 |
+
chunk_size=DEFAULT_CHUNK_SIZE,
|
168 |
+
margin_size=DEFAULT_MARGIN_SIZE,
|
169 |
+
):
|
170 |
+
"""
|
171 |
+
Segment or join segmented wave array
|
172 |
+
|
173 |
+
Args:
|
174 |
+
wave: (np.array) Wave array to be segmented or joined
|
175 |
+
combine: (bool) If True, combines segmented wave array.
|
176 |
+
If False, segments wave array.
|
177 |
+
chunk_size: (int) Size of each segment (in samples)
|
178 |
+
margin_size: (int) Size of margin between segments (in samples)
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
numpy array: Segmented or joined wave array
|
182 |
+
"""
|
183 |
+
|
184 |
+
if combine:
|
185 |
+
# Initializing as None instead of [] for later numpy array concatenation
|
186 |
+
processed_wave = None
|
187 |
+
for segment_count, segment in enumerate(wave):
|
188 |
+
start = 0 if segment_count == 0 else margin_size
|
189 |
+
end = None if segment_count == len(wave) - 1 else -margin_size
|
190 |
+
if margin_size == 0:
|
191 |
+
end = None
|
192 |
+
if processed_wave is None: # Create array for first segment
|
193 |
+
processed_wave = segment[:, start:end]
|
194 |
+
else: # Concatenate to existing array for subsequent segments
|
195 |
+
processed_wave = np.concatenate(
|
196 |
+
(processed_wave, segment[:, start:end]), axis=-1
|
197 |
+
)
|
198 |
+
|
199 |
+
else:
|
200 |
+
processed_wave = []
|
201 |
+
sample_count = wave.shape[-1]
|
202 |
+
|
203 |
+
if chunk_size <= 0 or chunk_size > sample_count:
|
204 |
+
chunk_size = sample_count
|
205 |
+
|
206 |
+
if margin_size > chunk_size:
|
207 |
+
margin_size = chunk_size
|
208 |
+
|
209 |
+
for segment_count, skip in enumerate(
|
210 |
+
range(0, sample_count, chunk_size)
|
211 |
+
):
|
212 |
+
margin = 0 if segment_count == 0 else margin_size
|
213 |
+
end = min(skip + chunk_size + margin_size, sample_count)
|
214 |
+
start = skip - margin
|
215 |
+
|
216 |
+
cut = wave[:, start:end].copy()
|
217 |
+
processed_wave.append(cut)
|
218 |
+
|
219 |
+
if end == sample_count:
|
220 |
+
break
|
221 |
+
|
222 |
+
return processed_wave
|
223 |
+
|
224 |
+
def pad_wave(self, wave):
|
225 |
+
"""
|
226 |
+
Pad the wave array to match the required chunk size
|
227 |
+
|
228 |
+
Args:
|
229 |
+
wave: (np.array) Wave array to be padded
|
230 |
+
|
231 |
+
Returns:
|
232 |
+
tuple: (padded_wave, pad, trim)
|
233 |
+
- padded_wave: Padded wave array
|
234 |
+
- pad: Number of samples that were padded
|
235 |
+
- trim: Number of samples that were trimmed
|
236 |
+
"""
|
237 |
+
n_sample = wave.shape[1]
|
238 |
+
trim = self.model.n_fft // 2
|
239 |
+
gen_size = self.model.chunk_size - 2 * trim
|
240 |
+
pad = gen_size - n_sample % gen_size
|
241 |
+
|
242 |
+
# Padded wave
|
243 |
+
wave_p = np.concatenate(
|
244 |
+
(
|
245 |
+
np.zeros((2, trim)),
|
246 |
+
wave,
|
247 |
+
np.zeros((2, pad)),
|
248 |
+
np.zeros((2, trim)),
|
249 |
+
),
|
250 |
+
1,
|
251 |
+
)
|
252 |
+
|
253 |
+
mix_waves = []
|
254 |
+
for i in range(0, n_sample + pad, gen_size):
|
255 |
+
waves = np.array(wave_p[:, i:i + self.model.chunk_size])
|
256 |
+
mix_waves.append(waves)
|
257 |
+
|
258 |
+
mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(
|
259 |
+
self.device
|
260 |
+
)
|
261 |
+
|
262 |
+
return mix_waves, pad, trim
|
263 |
+
|
264 |
+
def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
|
265 |
+
"""
|
266 |
+
Process each wave segment in a multi-threaded environment
|
267 |
+
|
268 |
+
Args:
|
269 |
+
mix_waves: (torch.Tensor) Wave segments to be processed
|
270 |
+
trim: (int) Number of samples trimmed during padding
|
271 |
+
pad: (int) Number of samples padded during padding
|
272 |
+
q: (queue.Queue) Queue to hold the processed wave segments
|
273 |
+
_id: (int) Identifier of the processed wave segment
|
274 |
+
|
275 |
+
Returns:
|
276 |
+
numpy array: Processed wave segment
|
277 |
+
"""
|
278 |
+
mix_waves = mix_waves.split(1)
|
279 |
+
with torch.no_grad():
|
280 |
+
pw = []
|
281 |
+
for mix_wave in mix_waves:
|
282 |
+
self.prog.update()
|
283 |
+
spec = self.model.stft(mix_wave)
|
284 |
+
processed_spec = torch.tensor(self.process(spec))
|
285 |
+
processed_wav = self.model.istft(
|
286 |
+
processed_spec.to(self.device)
|
287 |
+
)
|
288 |
+
processed_wav = (
|
289 |
+
processed_wav[:, :, trim:-trim]
|
290 |
+
.transpose(0, 1)
|
291 |
+
.reshape(2, -1)
|
292 |
+
.cpu()
|
293 |
+
.numpy()
|
294 |
+
)
|
295 |
+
pw.append(processed_wav)
|
296 |
+
processed_signal = np.concatenate(pw, axis=-1)[:, :-pad]
|
297 |
+
q.put({_id: processed_signal})
|
298 |
+
return processed_signal
|
299 |
+
|
300 |
+
def process_wave(self, wave: np.array, mt_threads=1):
|
301 |
+
"""
|
302 |
+
Process the wave array in a multi-threaded environment
|
303 |
+
|
304 |
+
Args:
|
305 |
+
wave: (np.array) Wave array to be processed
|
306 |
+
mt_threads: (int) Number of threads to be used for processing
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
numpy array: Processed wave array
|
310 |
+
"""
|
311 |
+
self.prog = tqdm(total=0)
|
312 |
+
chunk = wave.shape[-1] // mt_threads
|
313 |
+
waves = self.segment(wave, False, chunk)
|
314 |
+
|
315 |
+
# Create a queue to hold the processed wave segments
|
316 |
+
q = queue.Queue()
|
317 |
+
threads = []
|
318 |
+
for c, batch in enumerate(waves):
|
319 |
+
mix_waves, pad, trim = self.pad_wave(batch)
|
320 |
+
self.prog.total = len(mix_waves) * mt_threads
|
321 |
+
thread = threading.Thread(
|
322 |
+
target=self._process_wave, args=(mix_waves, trim, pad, q, c)
|
323 |
+
)
|
324 |
+
thread.start()
|
325 |
+
threads.append(thread)
|
326 |
+
for thread in threads:
|
327 |
+
thread.join()
|
328 |
+
self.prog.close()
|
329 |
+
|
330 |
+
processed_batches = []
|
331 |
+
while not q.empty():
|
332 |
+
processed_batches.append(q.get())
|
333 |
+
processed_batches = [
|
334 |
+
list(wave.values())[0]
|
335 |
+
for wave in sorted(
|
336 |
+
processed_batches, key=lambda d: list(d.keys())[0]
|
337 |
+
)
|
338 |
+
]
|
339 |
+
assert len(processed_batches) == len(
|
340 |
+
waves
|
341 |
+
), "Incomplete processed batches, please reduce batch size!"
|
342 |
+
return self.segment(processed_batches, True, chunk)
|
343 |
+
|
344 |
+
|
345 |
+
def run_mdx(
|
346 |
+
model_params,
|
347 |
+
output_dir,
|
348 |
+
model_path,
|
349 |
+
filename,
|
350 |
+
exclude_main=False,
|
351 |
+
exclude_inversion=False,
|
352 |
+
suffix=None,
|
353 |
+
invert_suffix=None,
|
354 |
+
denoise=False,
|
355 |
+
keep_orig=True,
|
356 |
+
m_threads=2,
|
357 |
+
device_base="cuda",
|
358 |
+
):
|
359 |
+
if device_base == "cuda":
|
360 |
+
device = torch.device("cuda:0")
|
361 |
+
processor_num = 0
|
362 |
+
device_properties = torch.cuda.get_device_properties(device)
|
363 |
+
vram_gb = device_properties.total_memory / 1024**3
|
364 |
+
m_threads = 1 if vram_gb < 8 else 2
|
365 |
+
else:
|
366 |
+
device = torch.device("cpu")
|
367 |
+
processor_num = -1
|
368 |
+
m_threads = 1
|
369 |
+
|
370 |
+
model_hash = MDX.get_hash(model_path)
|
371 |
+
mp = model_params.get(model_hash)
|
372 |
+
model = MDXModel(
|
373 |
+
device,
|
374 |
+
dim_f=mp["mdx_dim_f_set"],
|
375 |
+
dim_t=2 ** mp["mdx_dim_t_set"],
|
376 |
+
n_fft=mp["mdx_n_fft_scale_set"],
|
377 |
+
stem_name=mp["primary_stem"],
|
378 |
+
compensation=mp["compensate"],
|
379 |
+
)
|
380 |
+
|
381 |
+
mdx_sess = MDX(model_path, model, processor=processor_num)
|
382 |
+
wave, sr = librosa.load(filename, mono=False, sr=44100)
|
383 |
+
# normalizing input wave gives better output
|
384 |
+
peak = max(np.max(wave), abs(np.min(wave)))
|
385 |
+
wave /= peak
|
386 |
+
if denoise:
|
387 |
+
wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (
|
388 |
+
mdx_sess.process_wave(wave, m_threads)
|
389 |
+
)
|
390 |
+
wave_processed *= 0.5
|
391 |
+
else:
|
392 |
+
wave_processed = mdx_sess.process_wave(wave, m_threads)
|
393 |
+
# return to previous peak
|
394 |
+
wave_processed *= peak
|
395 |
+
stem_name = model.stem_name if suffix is None else suffix
|
396 |
+
|
397 |
+
main_filepath = None
|
398 |
+
if not exclude_main:
|
399 |
+
main_filepath = os.path.join(
|
400 |
+
output_dir,
|
401 |
+
f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
|
402 |
+
)
|
403 |
+
sf.write(main_filepath, wave_processed.T, sr)
|
404 |
+
|
405 |
+
invert_filepath = None
|
406 |
+
if not exclude_inversion:
|
407 |
+
diff_stem_name = (
|
408 |
+
stem_naming.get(stem_name)
|
409 |
+
if invert_suffix is None
|
410 |
+
else invert_suffix
|
411 |
+
)
|
412 |
+
stem_name = (
|
413 |
+
f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
|
414 |
+
)
|
415 |
+
invert_filepath = os.path.join(
|
416 |
+
output_dir,
|
417 |
+
f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
|
418 |
+
)
|
419 |
+
sf.write(
|
420 |
+
invert_filepath,
|
421 |
+
(-wave_processed.T * model.compensation) + wave.T,
|
422 |
+
sr,
|
423 |
+
)
|
424 |
+
|
425 |
+
if not keep_orig:
|
426 |
+
os.remove(filename)
|
427 |
+
|
428 |
+
del mdx_sess, wave_processed, wave
|
429 |
+
gc.collect()
|
430 |
+
torch.cuda.empty_cache()
|
431 |
+
return main_filepath, invert_filepath
|
432 |
+
|
433 |
+
|
434 |
+
MDX_DOWNLOAD_LINK = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/"
|
435 |
+
UVR_MODELS = [
|
436 |
+
"UVR-MDX-NET-Voc_FT.onnx",
|
437 |
+
"UVR_MDXNET_KARA_2.onnx",
|
438 |
+
"Reverb_HQ_By_FoxJoy.onnx",
|
439 |
+
"UVR-MDX-NET-Inst_HQ_4.onnx",
|
440 |
+
]
|
441 |
+
BASE_DIR = "." # os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
442 |
+
mdxnet_models_dir = os.path.join(BASE_DIR, "mdx_models")
|
443 |
+
output_dir = os.path.join(BASE_DIR, "clean_song_output")
|
444 |
+
|
445 |
+
|
446 |
+
def convert_to_stereo_and_wav(audio_path):
|
447 |
+
wave, sr = librosa.load(audio_path, mono=False, sr=44100)
|
448 |
+
|
449 |
+
# check if mono
|
450 |
+
if type(wave[0]) != np.ndarray or audio_path[-4:].lower() != ".wav": # noqa
|
451 |
+
stereo_path = f"{os.path.splitext(audio_path)[0]}_stereo.wav"
|
452 |
+
stereo_path = os.path.join(output_dir, stereo_path)
|
453 |
+
|
454 |
+
command = shlex.split(
|
455 |
+
f'ffmpeg -y -loglevel error -i "{audio_path}" -ac 2 -f wav "{stereo_path}"'
|
456 |
+
)
|
457 |
+
sub_params = {
|
458 |
+
"stdout": subprocess.PIPE,
|
459 |
+
"stderr": subprocess.PIPE,
|
460 |
+
"creationflags": subprocess.CREATE_NO_WINDOW
|
461 |
+
if sys.platform == "win32"
|
462 |
+
else 0,
|
463 |
+
}
|
464 |
+
process_wav = subprocess.Popen(command, **sub_params)
|
465 |
+
output, errors = process_wav.communicate()
|
466 |
+
if process_wav.returncode != 0 or not os.path.exists(stereo_path):
|
467 |
+
raise Exception("Error processing audio to stereo wav")
|
468 |
+
|
469 |
+
return stereo_path
|
470 |
+
else:
|
471 |
+
return audio_path
|
472 |
+
|
473 |
+
|
474 |
+
def process_uvr_task(
|
475 |
+
orig_song_path: str = "aud_test.mp3",
|
476 |
+
main_vocals: bool = False,
|
477 |
+
dereverb: bool = True,
|
478 |
+
song_id: str = "mdx", # folder output name
|
479 |
+
only_voiceless: bool = False,
|
480 |
+
remove_files_output_dir: bool = False,
|
481 |
+
):
|
482 |
+
if os.environ.get("SONITR_DEVICE") == "cpu":
|
483 |
+
device_base = "cpu"
|
484 |
+
else:
|
485 |
+
device_base = "cuda" if torch.cuda.is_available() else "cpu"
|
486 |
+
|
487 |
+
if remove_files_output_dir:
|
488 |
+
remove_directory_contents(output_dir)
|
489 |
+
|
490 |
+
with open(os.path.join(mdxnet_models_dir, "data.json")) as infile:
|
491 |
+
mdx_model_params = json.load(infile)
|
492 |
+
|
493 |
+
song_output_dir = os.path.join(output_dir, song_id)
|
494 |
+
create_directories(song_output_dir)
|
495 |
+
orig_song_path = convert_to_stereo_and_wav(orig_song_path)
|
496 |
+
|
497 |
+
logger.debug(f"onnxruntime device >> {ort.get_device()}")
|
498 |
+
|
499 |
+
if only_voiceless:
|
500 |
+
logger.info("Voiceless Track Separation...")
|
501 |
+
return run_mdx(
|
502 |
+
mdx_model_params,
|
503 |
+
song_output_dir,
|
504 |
+
os.path.join(mdxnet_models_dir, "UVR-MDX-NET-Inst_HQ_4.onnx"),
|
505 |
+
orig_song_path,
|
506 |
+
suffix="Voiceless",
|
507 |
+
denoise=False,
|
508 |
+
keep_orig=True,
|
509 |
+
exclude_inversion=True,
|
510 |
+
device_base=device_base,
|
511 |
+
)
|
512 |
+
|
513 |
+
logger.info("Vocal Track Isolation and Voiceless Track Separation...")
|
514 |
+
vocals_path, instrumentals_path = run_mdx(
|
515 |
+
mdx_model_params,
|
516 |
+
song_output_dir,
|
517 |
+
os.path.join(mdxnet_models_dir, "UVR-MDX-NET-Voc_FT.onnx"),
|
518 |
+
orig_song_path,
|
519 |
+
denoise=True,
|
520 |
+
keep_orig=True,
|
521 |
+
device_base=device_base,
|
522 |
+
)
|
523 |
+
|
524 |
+
if main_vocals:
|
525 |
+
logger.info("Main Voice Separation from Supporting Vocals...")
|
526 |
+
backup_vocals_path, main_vocals_path = run_mdx(
|
527 |
+
mdx_model_params,
|
528 |
+
song_output_dir,
|
529 |
+
os.path.join(mdxnet_models_dir, "UVR_MDXNET_KARA_2.onnx"),
|
530 |
+
vocals_path,
|
531 |
+
suffix="Backup",
|
532 |
+
invert_suffix="Main",
|
533 |
+
denoise=True,
|
534 |
+
device_base=device_base,
|
535 |
+
)
|
536 |
+
else:
|
537 |
+
backup_vocals_path, main_vocals_path = None, vocals_path
|
538 |
+
|
539 |
+
if dereverb:
|
540 |
+
logger.info("Vocal Clarity Enhancement through De-Reverberation...")
|
541 |
+
_, vocals_dereverb_path = run_mdx(
|
542 |
+
mdx_model_params,
|
543 |
+
song_output_dir,
|
544 |
+
os.path.join(mdxnet_models_dir, "Reverb_HQ_By_FoxJoy.onnx"),
|
545 |
+
main_vocals_path,
|
546 |
+
invert_suffix="DeReverb",
|
547 |
+
exclude_main=True,
|
548 |
+
denoise=True,
|
549 |
+
device_base=device_base,
|
550 |
+
)
|
551 |
+
else:
|
552 |
+
vocals_dereverb_path = main_vocals_path
|
553 |
+
|
554 |
+
return (
|
555 |
+
vocals_path,
|
556 |
+
instrumentals_path,
|
557 |
+
backup_vocals_path,
|
558 |
+
main_vocals_path,
|
559 |
+
vocals_dereverb_path,
|
560 |
+
)
|
561 |
+
|
562 |
+
|
563 |
+
if __name__ == "__main__":
|
564 |
+
from utils import download_manager
|
565 |
+
|
566 |
+
for id_model in UVR_MODELS:
|
567 |
+
download_manager(
|
568 |
+
os.path.join(MDX_DOWNLOAD_LINK, id_model), mdxnet_models_dir
|
569 |
+
)
|
570 |
+
(
|
571 |
+
vocals_path_,
|
572 |
+
instrumentals_path_,
|
573 |
+
backup_vocals_path_,
|
574 |
+
main_vocals_path_,
|
575 |
+
vocals_dereverb_path_,
|
576 |
+
) = process_uvr_task(
|
577 |
+
orig_song_path="aud.mp3",
|
578 |
+
main_vocals=True,
|
579 |
+
dereverb=True,
|
580 |
+
song_id="mdx",
|
581 |
+
remove_files_output_dir=True,
|
582 |
+
)
|
soni_translate/postprocessor.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .utils import remove_files, run_command
|
2 |
+
from .text_multiformat_processor import get_subtitle
|
3 |
+
from .logging_setup import logger
|
4 |
+
import unicodedata
|
5 |
+
import shutil
|
6 |
+
import copy
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
|
10 |
+
OUTPUT_TYPE_OPTIONS = [
|
11 |
+
"video (mp4)",
|
12 |
+
"video (mkv)",
|
13 |
+
"audio (mp3)",
|
14 |
+
"audio (ogg)",
|
15 |
+
"audio (wav)",
|
16 |
+
"subtitle",
|
17 |
+
"subtitle [by speaker]",
|
18 |
+
"video [subtitled] (mp4)",
|
19 |
+
"video [subtitled] (mkv)",
|
20 |
+
"audio [original vocal sound]",
|
21 |
+
"audio [original background sound]",
|
22 |
+
"audio [original vocal and background sound]",
|
23 |
+
"audio [original vocal-dereverb sound]",
|
24 |
+
"audio [original vocal-dereverb and background sound]",
|
25 |
+
"raw media",
|
26 |
+
]
|
27 |
+
|
28 |
+
DOCS_OUTPUT_TYPE_OPTIONS = [
|
29 |
+
"videobook (mp4)",
|
30 |
+
"videobook (mkv)",
|
31 |
+
"audiobook (wav)",
|
32 |
+
"audiobook (mp3)",
|
33 |
+
"audiobook (ogg)",
|
34 |
+
"book (txt)",
|
35 |
+
] # Add DOCX and etc.
|
36 |
+
|
37 |
+
|
38 |
+
def get_no_ext_filename(file_path):
|
39 |
+
file_name_with_extension = os.path.basename(rf"{file_path}")
|
40 |
+
filename_without_extension, _ = os.path.splitext(file_name_with_extension)
|
41 |
+
return filename_without_extension
|
42 |
+
|
43 |
+
|
44 |
+
def get_video_info(link):
|
45 |
+
aux_name = f"video_url_{link}"
|
46 |
+
params_dlp = {"quiet": True, "no_warnings": True, "noplaylist": True}
|
47 |
+
try:
|
48 |
+
from yt_dlp import YoutubeDL
|
49 |
+
|
50 |
+
with YoutubeDL(params_dlp) as ydl:
|
51 |
+
if link.startswith(("www.youtube.com/", "m.youtube.com/")):
|
52 |
+
link = "https://" + link
|
53 |
+
info_dict = ydl.extract_info(link, download=False, process=False)
|
54 |
+
video_id = info_dict.get("id", aux_name)
|
55 |
+
video_title = info_dict.get("title", video_id)
|
56 |
+
if "youtube.com" in link and "&list=" in link:
|
57 |
+
video_title = ydl.extract_info(
|
58 |
+
"https://m.youtube.com/watch?v="+video_id,
|
59 |
+
download=False,
|
60 |
+
process=False
|
61 |
+
).get("title", video_title)
|
62 |
+
except Exception as error:
|
63 |
+
logger.error(str(error))
|
64 |
+
video_title, video_id = aux_name, "NO_ID"
|
65 |
+
return video_title, video_id
|
66 |
+
|
67 |
+
|
68 |
+
def sanitize_file_name(file_name):
|
69 |
+
# Normalize the string to NFKD form to separate combined
|
70 |
+
# characters into base characters and diacritics
|
71 |
+
normalized_name = unicodedata.normalize("NFKD", file_name)
|
72 |
+
# Replace any non-ASCII characters or special symbols with an underscore
|
73 |
+
sanitized_name = re.sub(r"[^\w\s.-]", "_", normalized_name)
|
74 |
+
return sanitized_name
|
75 |
+
|
76 |
+
|
77 |
+
def get_output_file(
|
78 |
+
original_file,
|
79 |
+
new_file_name,
|
80 |
+
soft_subtitles,
|
81 |
+
output_directory="",
|
82 |
+
):
|
83 |
+
directory_base = "." # default directory
|
84 |
+
|
85 |
+
if output_directory and os.path.isdir(output_directory):
|
86 |
+
new_file_path = os.path.join(output_directory, new_file_name)
|
87 |
+
else:
|
88 |
+
new_file_path = os.path.join(directory_base, "outputs", new_file_name)
|
89 |
+
remove_files(new_file_path)
|
90 |
+
|
91 |
+
cm = None
|
92 |
+
if soft_subtitles and original_file.endswith(".mp4"):
|
93 |
+
if new_file_path.endswith(".mp4"):
|
94 |
+
cm = f'ffmpeg -y -i "{original_file}" -i sub_tra.srt -i sub_ori.srt -map 0:v -map 0:a -map 1 -map 2 -c:v copy -c:a copy -c:s mov_text "{new_file_path}"'
|
95 |
+
else:
|
96 |
+
cm = f'ffmpeg -y -i "{original_file}" -i sub_tra.srt -i sub_ori.srt -map 0:v -map 0:a -map 1 -map 2 -c:v copy -c:a copy -c:s srt -movflags use_metadata_tags -map_metadata 0 "{new_file_path}"'
|
97 |
+
elif new_file_path.endswith(".mkv"):
|
98 |
+
cm = f'ffmpeg -i "{original_file}" -c:v copy -c:a copy "{new_file_path}"'
|
99 |
+
elif new_file_path.endswith(".wav") and not original_file.endswith(".wav"):
|
100 |
+
cm = f'ffmpeg -y -i "{original_file}" -acodec pcm_s16le -ar 44100 -ac 2 "{new_file_path}"'
|
101 |
+
elif new_file_path.endswith(".ogg"):
|
102 |
+
cm = f'ffmpeg -i "{original_file}" -c:a libvorbis "{new_file_path}"'
|
103 |
+
elif new_file_path.endswith(".mp3") and not original_file.endswith(".mp3"):
|
104 |
+
cm = f'ffmpeg -y -i "{original_file}" -codec:a libmp3lame -qscale:a 2 "{new_file_path}"'
|
105 |
+
|
106 |
+
if cm:
|
107 |
+
try:
|
108 |
+
run_command(cm)
|
109 |
+
except Exception as error:
|
110 |
+
logger.error(str(error))
|
111 |
+
remove_files(new_file_path)
|
112 |
+
shutil.copy2(original_file, new_file_path)
|
113 |
+
else:
|
114 |
+
shutil.copy2(original_file, new_file_path)
|
115 |
+
|
116 |
+
return os.path.abspath(new_file_path)
|
117 |
+
|
118 |
+
|
119 |
+
def media_out(
|
120 |
+
media_file,
|
121 |
+
lang_code,
|
122 |
+
media_out_name="",
|
123 |
+
extension="mp4",
|
124 |
+
file_obj="video_dub.mp4",
|
125 |
+
soft_subtitles=False,
|
126 |
+
subtitle_files="disable",
|
127 |
+
):
|
128 |
+
if not media_out_name:
|
129 |
+
if os.path.exists(media_file):
|
130 |
+
base_name = get_no_ext_filename(media_file)
|
131 |
+
else:
|
132 |
+
base_name, _ = get_video_info(media_file)
|
133 |
+
|
134 |
+
media_out_name = f"{base_name}__{lang_code}"
|
135 |
+
|
136 |
+
f_name = f"{sanitize_file_name(media_out_name)}.{extension}"
|
137 |
+
|
138 |
+
if subtitle_files != "disable":
|
139 |
+
final_media = [get_output_file(file_obj, f_name, soft_subtitles)]
|
140 |
+
name_tra = f"{sanitize_file_name(media_out_name)}.{subtitle_files}"
|
141 |
+
name_ori = f"{sanitize_file_name(base_name)}.{subtitle_files}"
|
142 |
+
tgt_subs = f"sub_tra.{subtitle_files}"
|
143 |
+
ori_subs = f"sub_ori.{subtitle_files}"
|
144 |
+
final_subtitles = [
|
145 |
+
get_output_file(tgt_subs, name_tra, False),
|
146 |
+
get_output_file(ori_subs, name_ori, False)
|
147 |
+
]
|
148 |
+
return final_media + final_subtitles
|
149 |
+
else:
|
150 |
+
return get_output_file(file_obj, f_name, soft_subtitles)
|
151 |
+
|
152 |
+
|
153 |
+
def get_subtitle_speaker(media_file, result, language, extension, base_name):
|
154 |
+
|
155 |
+
segments_base = copy.deepcopy(result)
|
156 |
+
|
157 |
+
# Sub segments by speaker
|
158 |
+
segments_by_speaker = {}
|
159 |
+
for segment in segments_base["segments"]:
|
160 |
+
if segment["speaker"] not in segments_by_speaker.keys():
|
161 |
+
segments_by_speaker[segment["speaker"]] = [segment]
|
162 |
+
else:
|
163 |
+
segments_by_speaker[segment["speaker"]].append(segment)
|
164 |
+
|
165 |
+
if not base_name:
|
166 |
+
if os.path.exists(media_file):
|
167 |
+
base_name = get_no_ext_filename(media_file)
|
168 |
+
else:
|
169 |
+
base_name, _ = get_video_info(media_file)
|
170 |
+
|
171 |
+
files_subs = []
|
172 |
+
for name_sk, segments in segments_by_speaker.items():
|
173 |
+
|
174 |
+
subtitle_speaker = get_subtitle(
|
175 |
+
language,
|
176 |
+
{"segments": segments},
|
177 |
+
extension,
|
178 |
+
filename=name_sk,
|
179 |
+
)
|
180 |
+
|
181 |
+
media_out_name = f"{base_name}_{language}_{name_sk}"
|
182 |
+
|
183 |
+
output = media_out(
|
184 |
+
media_file, # no need
|
185 |
+
language,
|
186 |
+
media_out_name,
|
187 |
+
extension,
|
188 |
+
file_obj=subtitle_speaker,
|
189 |
+
)
|
190 |
+
|
191 |
+
files_subs.append(output)
|
192 |
+
|
193 |
+
return files_subs
|
194 |
+
|
195 |
+
|
196 |
+
def sound_separate(media_file, task_uvr):
|
197 |
+
from .mdx_net import process_uvr_task
|
198 |
+
|
199 |
+
outputs = []
|
200 |
+
|
201 |
+
if "vocal" in task_uvr:
|
202 |
+
try:
|
203 |
+
_, _, _, _, vocal_audio = process_uvr_task(
|
204 |
+
orig_song_path=media_file,
|
205 |
+
main_vocals=False,
|
206 |
+
dereverb=True if "dereverb" in task_uvr else False,
|
207 |
+
remove_files_output_dir=True,
|
208 |
+
)
|
209 |
+
outputs.append(vocal_audio)
|
210 |
+
except Exception as error:
|
211 |
+
logger.error(str(error))
|
212 |
+
|
213 |
+
if "background" in task_uvr:
|
214 |
+
try:
|
215 |
+
background_audio, _ = process_uvr_task(
|
216 |
+
orig_song_path=media_file,
|
217 |
+
song_id="voiceless",
|
218 |
+
only_voiceless=True,
|
219 |
+
remove_files_output_dir=False if "vocal" in task_uvr else True,
|
220 |
+
)
|
221 |
+
# copy_files(background_audio, ".")
|
222 |
+
outputs.append(background_audio)
|
223 |
+
except Exception as error:
|
224 |
+
logger.error(str(error))
|
225 |
+
|
226 |
+
if not outputs:
|
227 |
+
raise Exception("Error in uvr process")
|
228 |
+
|
229 |
+
return outputs
|
soni_translate/preprocessor.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .utils import remove_files
|
2 |
+
import os, shutil, subprocess, time, shlex, sys # noqa
|
3 |
+
from .logging_setup import logger
|
4 |
+
import json
|
5 |
+
|
6 |
+
ERROR_INCORRECT_CODEC_PARAMETERS = [
|
7 |
+
"prores", # mov
|
8 |
+
"ffv1", # mkv
|
9 |
+
"msmpeg4v3", # avi
|
10 |
+
"wmv2", # wmv
|
11 |
+
"theora", # ogv
|
12 |
+
] # fix final merge
|
13 |
+
|
14 |
+
TESTED_CODECS = [
|
15 |
+
"h264", # mp4
|
16 |
+
"h265", # mp4
|
17 |
+
"vp9", # webm
|
18 |
+
"mpeg4", # mp4
|
19 |
+
"mpeg2video", # mpg
|
20 |
+
"mjpeg", # avi
|
21 |
+
]
|
22 |
+
|
23 |
+
|
24 |
+
class OperationFailedError(Exception):
|
25 |
+
def __init__(self, message="The operation did not complete successfully."):
|
26 |
+
self.message = message
|
27 |
+
super().__init__(self.message)
|
28 |
+
|
29 |
+
|
30 |
+
def get_video_codec(video_file):
|
31 |
+
command_base = rf'ffprobe -v error -select_streams v:0 -show_entries stream=codec_name -of json "{video_file}"'
|
32 |
+
command = shlex.split(command_base)
|
33 |
+
try:
|
34 |
+
process = subprocess.Popen(
|
35 |
+
command,
|
36 |
+
stdout=subprocess.PIPE,
|
37 |
+
creationflags=subprocess.CREATE_NO_WINDOW if sys.platform == "win32" else 0,
|
38 |
+
)
|
39 |
+
output, _ = process.communicate()
|
40 |
+
codec_info = json.loads(output.decode('utf-8'))
|
41 |
+
codec_name = codec_info['streams'][0]['codec_name']
|
42 |
+
return codec_name
|
43 |
+
except Exception as error:
|
44 |
+
logger.debug(str(error))
|
45 |
+
return None
|
46 |
+
|
47 |
+
|
48 |
+
def audio_preprocessor(preview, base_audio, audio_wav, use_cuda=False):
|
49 |
+
base_audio = base_audio.strip()
|
50 |
+
previous_files_to_remove = [audio_wav]
|
51 |
+
remove_files(previous_files_to_remove)
|
52 |
+
|
53 |
+
if preview:
|
54 |
+
logger.warning(
|
55 |
+
"Creating a preview video of 10 seconds, to disable "
|
56 |
+
"this option, go to advanced settings and turn off preview."
|
57 |
+
)
|
58 |
+
wav_ = f'ffmpeg -y -i "{base_audio}" -ss 00:00:20 -t 00:00:10 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav'
|
59 |
+
else:
|
60 |
+
wav_ = f'ffmpeg -y -i "{base_audio}" -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav'
|
61 |
+
|
62 |
+
# Run cmd process
|
63 |
+
sub_params = {
|
64 |
+
"stdout": subprocess.PIPE,
|
65 |
+
"stderr": subprocess.PIPE,
|
66 |
+
"creationflags": subprocess.CREATE_NO_WINDOW
|
67 |
+
if sys.platform == "win32"
|
68 |
+
else 0,
|
69 |
+
}
|
70 |
+
wav_ = shlex.split(wav_)
|
71 |
+
result_convert_audio = subprocess.Popen(wav_, **sub_params)
|
72 |
+
output, errors = result_convert_audio.communicate()
|
73 |
+
time.sleep(1)
|
74 |
+
if result_convert_audio.returncode in [1, 2] or not os.path.exists(
|
75 |
+
audio_wav
|
76 |
+
):
|
77 |
+
raise OperationFailedError(f"Error can't create the audio file:\n{errors.decode('utf-8')}")
|
78 |
+
|
79 |
+
|
80 |
+
def audio_video_preprocessor(
|
81 |
+
preview, video, OutputFile, audio_wav, use_cuda=False
|
82 |
+
):
|
83 |
+
video = video.strip()
|
84 |
+
previous_files_to_remove = [OutputFile, "audio.webm", audio_wav]
|
85 |
+
remove_files(previous_files_to_remove)
|
86 |
+
|
87 |
+
if os.path.exists(video):
|
88 |
+
if preview:
|
89 |
+
logger.warning(
|
90 |
+
"Creating a preview video of 10 seconds, "
|
91 |
+
"to disable this option, go to advanced "
|
92 |
+
"settings and turn off preview."
|
93 |
+
)
|
94 |
+
mp4_ = f'ffmpeg -y -i "{video}" -ss 00:00:20 -t 00:00:10 -c:v libx264 -c:a aac -strict experimental Video.mp4'
|
95 |
+
else:
|
96 |
+
video_codec = get_video_codec(video)
|
97 |
+
if not video_codec:
|
98 |
+
logger.debug("No video codec found in video")
|
99 |
+
else:
|
100 |
+
logger.info(f"Video codec: {video_codec}")
|
101 |
+
|
102 |
+
# Check if the file ends with ".mp4" extension or is valid codec
|
103 |
+
if video.endswith(".mp4") or video_codec in TESTED_CODECS:
|
104 |
+
destination_path = os.path.join(os.getcwd(), "Video.mp4")
|
105 |
+
shutil.copy(video, destination_path)
|
106 |
+
time.sleep(0.5)
|
107 |
+
if os.path.exists(OutputFile):
|
108 |
+
mp4_ = "ffmpeg -h"
|
109 |
+
else:
|
110 |
+
mp4_ = f'ffmpeg -y -i "{video}" -c copy Video.mp4'
|
111 |
+
else:
|
112 |
+
logger.warning(
|
113 |
+
"File does not have the '.mp4' extension or a "
|
114 |
+
"supported codec. Converting video to mp4 (codec: h264)."
|
115 |
+
)
|
116 |
+
mp4_ = f'ffmpeg -y -i "{video}" -c:v libx264 -c:a aac -strict experimental Video.mp4'
|
117 |
+
else:
|
118 |
+
if preview:
|
119 |
+
logger.warning(
|
120 |
+
"Creating a preview from the link, 10 seconds "
|
121 |
+
"to disable this option, go to advanced "
|
122 |
+
"settings and turn off preview."
|
123 |
+
)
|
124 |
+
# https://github.com/yt-dlp/yt-dlp/issues/2220
|
125 |
+
mp4_ = f'yt-dlp -f "mp4" --downloader ffmpeg --downloader-args "ffmpeg_i: -ss 00:00:20 -t 00:00:10" --force-overwrites --max-downloads 1 --no-warnings --no-playlist --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
|
126 |
+
wav_ = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
|
127 |
+
else:
|
128 |
+
mp4_ = f'yt-dlp -f "mp4" --force-overwrites --max-downloads 1 --no-warnings --no-playlist --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
|
129 |
+
wav_ = f"python -m yt_dlp --output {audio_wav} --force-overwrites --max-downloads 1 --no-warnings --no-playlist --no-abort-on-error --ignore-no-formats-error --extract-audio --audio-format wav {video}"
|
130 |
+
|
131 |
+
# Run cmd process
|
132 |
+
mp4_ = shlex.split(mp4_)
|
133 |
+
sub_params = {
|
134 |
+
"stdout": subprocess.PIPE,
|
135 |
+
"stderr": subprocess.PIPE,
|
136 |
+
"creationflags": subprocess.CREATE_NO_WINDOW
|
137 |
+
if sys.platform == "win32"
|
138 |
+
else 0,
|
139 |
+
}
|
140 |
+
|
141 |
+
if os.path.exists(video):
|
142 |
+
logger.info("Process video...")
|
143 |
+
result_convert_video = subprocess.Popen(mp4_, **sub_params)
|
144 |
+
# result_convert_video.wait()
|
145 |
+
output, errors = result_convert_video.communicate()
|
146 |
+
time.sleep(1)
|
147 |
+
if result_convert_video.returncode in [1, 2] or not os.path.exists(
|
148 |
+
OutputFile
|
149 |
+
):
|
150 |
+
raise OperationFailedError(f"Error processing video:\n{errors.decode('utf-8')}")
|
151 |
+
logger.info("Process audio...")
|
152 |
+
wav_ = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
|
153 |
+
wav_ = shlex.split(wav_)
|
154 |
+
result_convert_audio = subprocess.Popen(wav_, **sub_params)
|
155 |
+
output, errors = result_convert_audio.communicate()
|
156 |
+
time.sleep(1)
|
157 |
+
if result_convert_audio.returncode in [1, 2] or not os.path.exists(
|
158 |
+
audio_wav
|
159 |
+
):
|
160 |
+
raise OperationFailedError(f"Error can't create the audio file:\n{errors.decode('utf-8')}")
|
161 |
+
|
162 |
+
else:
|
163 |
+
wav_ = shlex.split(wav_)
|
164 |
+
if preview:
|
165 |
+
result_convert_video = subprocess.Popen(mp4_, **sub_params)
|
166 |
+
output, errors = result_convert_video.communicate()
|
167 |
+
time.sleep(0.5)
|
168 |
+
result_convert_audio = subprocess.Popen(wav_, **sub_params)
|
169 |
+
output, errors = result_convert_audio.communicate()
|
170 |
+
time.sleep(0.5)
|
171 |
+
if result_convert_audio.returncode in [1, 2] or not os.path.exists(
|
172 |
+
audio_wav
|
173 |
+
):
|
174 |
+
raise OperationFailedError(
|
175 |
+
f"Error can't create the preview file:\n{errors.decode('utf-8')}"
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
logger.info("Process audio...")
|
179 |
+
result_convert_audio = subprocess.Popen(wav_, **sub_params)
|
180 |
+
output, errors = result_convert_audio.communicate()
|
181 |
+
time.sleep(1)
|
182 |
+
if result_convert_audio.returncode in [1, 2] or not os.path.exists(
|
183 |
+
audio_wav
|
184 |
+
):
|
185 |
+
raise OperationFailedError(f"Error can't download the audio:\n{errors.decode('utf-8')}")
|
186 |
+
logger.info("Process video...")
|
187 |
+
result_convert_video = subprocess.Popen(mp4_, **sub_params)
|
188 |
+
output, errors = result_convert_video.communicate()
|
189 |
+
time.sleep(1)
|
190 |
+
if result_convert_video.returncode in [1, 2] or not os.path.exists(
|
191 |
+
OutputFile
|
192 |
+
):
|
193 |
+
raise OperationFailedError(f"Error can't download the video:\n{errors.decode('utf-8')}")
|
194 |
+
|
195 |
+
|
196 |
+
def old_audio_video_preprocessor(preview, video, OutputFile, audio_wav):
|
197 |
+
previous_files_to_remove = [OutputFile, "audio.webm", audio_wav]
|
198 |
+
remove_files(previous_files_to_remove)
|
199 |
+
|
200 |
+
if os.path.exists(video):
|
201 |
+
if preview:
|
202 |
+
logger.warning(
|
203 |
+
"Creating a preview video of 10 seconds, "
|
204 |
+
"to disable this option, go to advanced "
|
205 |
+
"settings and turn off preview."
|
206 |
+
)
|
207 |
+
command = f'ffmpeg -y -i "{video}" -ss 00:00:20 -t 00:00:10 -c:v libx264 -c:a aac -strict experimental Video.mp4'
|
208 |
+
result_convert_video = subprocess.run(
|
209 |
+
command, capture_output=True, text=True, shell=True
|
210 |
+
)
|
211 |
+
else:
|
212 |
+
# Check if the file ends with ".mp4" extension
|
213 |
+
if video.endswith(".mp4"):
|
214 |
+
destination_path = os.path.join(os.getcwd(), "Video.mp4")
|
215 |
+
shutil.copy(video, destination_path)
|
216 |
+
result_convert_video = {}
|
217 |
+
result_convert_video = subprocess.run(
|
218 |
+
"echo Video copied",
|
219 |
+
capture_output=True,
|
220 |
+
text=True,
|
221 |
+
shell=True,
|
222 |
+
)
|
223 |
+
else:
|
224 |
+
logger.warning(
|
225 |
+
"File does not have the '.mp4' extension. Converting video."
|
226 |
+
)
|
227 |
+
command = f'ffmpeg -y -i "{video}" -c:v libx264 -c:a aac -strict experimental Video.mp4'
|
228 |
+
result_convert_video = subprocess.run(
|
229 |
+
command, capture_output=True, text=True, shell=True
|
230 |
+
)
|
231 |
+
|
232 |
+
if result_convert_video.returncode in [1, 2]:
|
233 |
+
raise OperationFailedError("Error can't convert the video")
|
234 |
+
|
235 |
+
for i in range(120):
|
236 |
+
time.sleep(1)
|
237 |
+
logger.info("Process video...")
|
238 |
+
if os.path.exists(OutputFile):
|
239 |
+
time.sleep(1)
|
240 |
+
command = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
|
241 |
+
result_convert_audio = subprocess.run(
|
242 |
+
command, capture_output=True, text=True, shell=True
|
243 |
+
)
|
244 |
+
time.sleep(1)
|
245 |
+
break
|
246 |
+
if i == 119:
|
247 |
+
# if not os.path.exists(OutputFile):
|
248 |
+
raise OperationFailedError("Error processing video")
|
249 |
+
|
250 |
+
if result_convert_audio.returncode in [1, 2]:
|
251 |
+
raise OperationFailedError(
|
252 |
+
f"Error can't create the audio file: {result_convert_audio.stderr}"
|
253 |
+
)
|
254 |
+
|
255 |
+
for i in range(120):
|
256 |
+
time.sleep(1)
|
257 |
+
logger.info("Process audio...")
|
258 |
+
if os.path.exists(audio_wav):
|
259 |
+
break
|
260 |
+
if i == 119:
|
261 |
+
raise OperationFailedError("Error can't create the audio file")
|
262 |
+
|
263 |
+
else:
|
264 |
+
video = video.strip()
|
265 |
+
if preview:
|
266 |
+
logger.warning(
|
267 |
+
"Creating a preview from the link, 10 "
|
268 |
+
"seconds to disable this option, go to "
|
269 |
+
"advanced settings and turn off preview."
|
270 |
+
)
|
271 |
+
# https://github.com/yt-dlp/yt-dlp/issues/2220
|
272 |
+
mp4_ = f'yt-dlp -f "mp4" --downloader ffmpeg --downloader-args "ffmpeg_i: -ss 00:00:20 -t 00:00:10" --force-overwrites --max-downloads 1 --no-warnings --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
|
273 |
+
wav_ = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
|
274 |
+
result_convert_video = subprocess.run(
|
275 |
+
mp4_, capture_output=True, text=True, shell=True
|
276 |
+
)
|
277 |
+
result_convert_audio = subprocess.run(
|
278 |
+
wav_, capture_output=True, text=True, shell=True
|
279 |
+
)
|
280 |
+
if result_convert_audio.returncode in [1, 2]:
|
281 |
+
raise OperationFailedError("Error can't download a preview")
|
282 |
+
else:
|
283 |
+
mp4_ = f'yt-dlp -f "mp4" --force-overwrites --max-downloads 1 --no-warnings --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
|
284 |
+
wav_ = f"python -m yt_dlp --output {audio_wav} --force-overwrites --max-downloads 1 --no-warnings --no-abort-on-error --ignore-no-formats-error --extract-audio --audio-format wav {video}"
|
285 |
+
|
286 |
+
result_convert_audio = subprocess.run(
|
287 |
+
wav_, capture_output=True, text=True, shell=True
|
288 |
+
)
|
289 |
+
|
290 |
+
if result_convert_audio.returncode in [1, 2]:
|
291 |
+
raise OperationFailedError("Error can't download the audio")
|
292 |
+
|
293 |
+
for i in range(120):
|
294 |
+
time.sleep(1)
|
295 |
+
logger.info("Process audio...")
|
296 |
+
if os.path.exists(audio_wav) and not os.path.exists(
|
297 |
+
"audio.webm"
|
298 |
+
):
|
299 |
+
time.sleep(1)
|
300 |
+
result_convert_video = subprocess.run(
|
301 |
+
mp4_, capture_output=True, text=True, shell=True
|
302 |
+
)
|
303 |
+
break
|
304 |
+
if i == 119:
|
305 |
+
raise OperationFailedError("Error downloading the audio")
|
306 |
+
|
307 |
+
if result_convert_video.returncode in [1, 2]:
|
308 |
+
raise OperationFailedError("Error can't download the video")
|
soni_translate/speech_segmentation.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from whisperx.alignment import (
|
2 |
+
DEFAULT_ALIGN_MODELS_TORCH as DAMT,
|
3 |
+
DEFAULT_ALIGN_MODELS_HF as DAMHF,
|
4 |
+
)
|
5 |
+
from whisperx.utils import TO_LANGUAGE_CODE
|
6 |
+
import whisperx
|
7 |
+
import torch
|
8 |
+
import gc
|
9 |
+
import os
|
10 |
+
import soundfile as sf
|
11 |
+
from IPython.utils import capture # noqa
|
12 |
+
from .language_configuration import EXTRA_ALIGN, INVERTED_LANGUAGES
|
13 |
+
from .logging_setup import logger
|
14 |
+
from .postprocessor import sanitize_file_name
|
15 |
+
from .utils import remove_directory_contents, run_command
|
16 |
+
|
17 |
+
# ZERO GPU CONFIG
|
18 |
+
import spaces
|
19 |
+
import copy
|
20 |
+
import random
|
21 |
+
import time
|
22 |
+
|
23 |
+
def random_sleep():
|
24 |
+
if os.environ.get("ZERO_GPU") == "TRUE":
|
25 |
+
print("Random sleep")
|
26 |
+
sleep_time = round(random.uniform(7.2, 9.9), 1)
|
27 |
+
time.sleep(sleep_time)
|
28 |
+
|
29 |
+
|
30 |
+
@spaces.GPU(duration=120)
|
31 |
+
def load_and_transcribe_audio(asr_model, audio, compute_type, language, asr_options, batch_size, segment_duration_limit):
|
32 |
+
# Load model
|
33 |
+
model = whisperx.load_model(
|
34 |
+
asr_model,
|
35 |
+
os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
|
36 |
+
compute_type=compute_type,
|
37 |
+
language=language,
|
38 |
+
asr_options=asr_options,
|
39 |
+
)
|
40 |
+
|
41 |
+
# Transcribe audio
|
42 |
+
result = model.transcribe(
|
43 |
+
audio,
|
44 |
+
batch_size=batch_size,
|
45 |
+
chunk_size=segment_duration_limit,
|
46 |
+
print_progress=True,
|
47 |
+
)
|
48 |
+
|
49 |
+
del model
|
50 |
+
gc.collect()
|
51 |
+
torch.cuda.empty_cache() # noqa
|
52 |
+
|
53 |
+
return result
|
54 |
+
|
55 |
+
def load_align_and_align_segments(result, audio, DAMHF):
|
56 |
+
|
57 |
+
# Load alignment model
|
58 |
+
model_a, metadata = whisperx.load_align_model(
|
59 |
+
language_code=result["language"],
|
60 |
+
device=os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
|
61 |
+
model_name=None
|
62 |
+
if result["language"] in DAMHF.keys()
|
63 |
+
else EXTRA_ALIGN[result["language"]],
|
64 |
+
)
|
65 |
+
|
66 |
+
# Align segments
|
67 |
+
alignment_result = whisperx.align(
|
68 |
+
result["segments"],
|
69 |
+
model_a,
|
70 |
+
metadata,
|
71 |
+
audio,
|
72 |
+
os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
|
73 |
+
return_char_alignments=True,
|
74 |
+
print_progress=False,
|
75 |
+
)
|
76 |
+
|
77 |
+
# Clean up
|
78 |
+
del model_a
|
79 |
+
gc.collect()
|
80 |
+
torch.cuda.empty_cache() # noqa
|
81 |
+
|
82 |
+
return alignment_result
|
83 |
+
|
84 |
+
@spaces.GPU(duration=120)
|
85 |
+
def diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers):
|
86 |
+
|
87 |
+
if os.environ.get("ZERO_GPU") == "TRUE":
|
88 |
+
diarize_model.model.to(torch.device("cuda"))
|
89 |
+
diarize_segments = diarize_model(
|
90 |
+
audio_wav,
|
91 |
+
min_speakers=min_speakers,
|
92 |
+
max_speakers=max_speakers
|
93 |
+
)
|
94 |
+
return diarize_segments
|
95 |
+
|
96 |
+
# ZERO GPU CONFIG
|
97 |
+
|
98 |
+
ASR_MODEL_OPTIONS = [
|
99 |
+
"tiny",
|
100 |
+
"base",
|
101 |
+
"small",
|
102 |
+
"medium",
|
103 |
+
"large",
|
104 |
+
"large-v1",
|
105 |
+
"large-v2",
|
106 |
+
"large-v3",
|
107 |
+
"distil-large-v2",
|
108 |
+
"Systran/faster-distil-whisper-large-v3",
|
109 |
+
"tiny.en",
|
110 |
+
"base.en",
|
111 |
+
"small.en",
|
112 |
+
"medium.en",
|
113 |
+
"distil-small.en",
|
114 |
+
"distil-medium.en",
|
115 |
+
"OpenAI_API_Whisper",
|
116 |
+
]
|
117 |
+
|
118 |
+
COMPUTE_TYPE_GPU = [
|
119 |
+
"default",
|
120 |
+
"auto",
|
121 |
+
"int8",
|
122 |
+
"int8_float32",
|
123 |
+
"int8_float16",
|
124 |
+
"int8_bfloat16",
|
125 |
+
"float16",
|
126 |
+
"bfloat16",
|
127 |
+
"float32"
|
128 |
+
]
|
129 |
+
|
130 |
+
COMPUTE_TYPE_CPU = [
|
131 |
+
"default",
|
132 |
+
"auto",
|
133 |
+
"int8",
|
134 |
+
"int8_float32",
|
135 |
+
"int16",
|
136 |
+
"float32",
|
137 |
+
]
|
138 |
+
|
139 |
+
WHISPER_MODELS_PATH = './WHISPER_MODELS'
|
140 |
+
|
141 |
+
|
142 |
+
def openai_api_whisper(
|
143 |
+
input_audio_file,
|
144 |
+
source_lang=None,
|
145 |
+
chunk_duration=1800
|
146 |
+
):
|
147 |
+
|
148 |
+
info = sf.info(input_audio_file)
|
149 |
+
duration = info.duration
|
150 |
+
|
151 |
+
output_directory = "./whisper_api_audio_parts"
|
152 |
+
os.makedirs(output_directory, exist_ok=True)
|
153 |
+
remove_directory_contents(output_directory)
|
154 |
+
|
155 |
+
if duration > chunk_duration:
|
156 |
+
# Split the audio file into smaller chunks with 30-minute duration
|
157 |
+
cm = f'ffmpeg -i "{input_audio_file}" -f segment -segment_time {chunk_duration} -c:a libvorbis "{output_directory}/output%03d.ogg"'
|
158 |
+
run_command(cm)
|
159 |
+
# Get list of generated chunk files
|
160 |
+
chunk_files = sorted(
|
161 |
+
[f"{output_directory}/{f}" for f in os.listdir(output_directory) if f.endswith('.ogg')]
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
one_file = f"{output_directory}/output000.ogg"
|
165 |
+
cm = f'ffmpeg -i "{input_audio_file}" -c:a libvorbis {one_file}'
|
166 |
+
run_command(cm)
|
167 |
+
chunk_files = [one_file]
|
168 |
+
|
169 |
+
# Transcript
|
170 |
+
segments = []
|
171 |
+
language = source_lang if source_lang else None
|
172 |
+
for i, chunk in enumerate(chunk_files):
|
173 |
+
from openai import OpenAI
|
174 |
+
client = OpenAI()
|
175 |
+
|
176 |
+
audio_file = open(chunk, "rb")
|
177 |
+
transcription = client.audio.transcriptions.create(
|
178 |
+
model="whisper-1",
|
179 |
+
file=audio_file,
|
180 |
+
language=language,
|
181 |
+
response_format="verbose_json",
|
182 |
+
timestamp_granularities=["segment"],
|
183 |
+
)
|
184 |
+
|
185 |
+
try:
|
186 |
+
transcript_dict = transcription.model_dump()
|
187 |
+
except: # noqa
|
188 |
+
transcript_dict = transcription.to_dict()
|
189 |
+
|
190 |
+
if language is None:
|
191 |
+
logger.info(f'Language detected: {transcript_dict["language"]}')
|
192 |
+
language = TO_LANGUAGE_CODE[transcript_dict["language"]]
|
193 |
+
|
194 |
+
chunk_time = chunk_duration * (i)
|
195 |
+
|
196 |
+
for seg in transcript_dict["segments"]:
|
197 |
+
|
198 |
+
if "start" in seg.keys():
|
199 |
+
segments.append(
|
200 |
+
{
|
201 |
+
"text": seg["text"],
|
202 |
+
"start": seg["start"] + chunk_time,
|
203 |
+
"end": seg["end"] + chunk_time,
|
204 |
+
}
|
205 |
+
)
|
206 |
+
|
207 |
+
audio = whisperx.load_audio(input_audio_file)
|
208 |
+
result = {"segments": segments, "language": language}
|
209 |
+
|
210 |
+
return audio, result
|
211 |
+
|
212 |
+
|
213 |
+
def find_whisper_models():
|
214 |
+
path = WHISPER_MODELS_PATH
|
215 |
+
folders = []
|
216 |
+
|
217 |
+
if os.path.exists(path):
|
218 |
+
for folder in os.listdir(path):
|
219 |
+
folder_path = os.path.join(path, folder)
|
220 |
+
if (
|
221 |
+
os.path.isdir(folder_path)
|
222 |
+
and 'model.bin' in os.listdir(folder_path)
|
223 |
+
):
|
224 |
+
folders.append(folder)
|
225 |
+
return folders
|
226 |
+
|
227 |
+
def transcribe_speech(
|
228 |
+
audio_wav,
|
229 |
+
asr_model,
|
230 |
+
compute_type,
|
231 |
+
batch_size,
|
232 |
+
SOURCE_LANGUAGE,
|
233 |
+
literalize_numbers=True,
|
234 |
+
segment_duration_limit=15,
|
235 |
+
):
|
236 |
+
"""
|
237 |
+
Transcribe speech using a whisper model.
|
238 |
+
|
239 |
+
Parameters:
|
240 |
+
- audio_wav (str): Path to the audio file in WAV format.
|
241 |
+
- asr_model (str): The whisper model to be loaded.
|
242 |
+
- compute_type (str): Type of compute to be used (e.g., 'int8', 'float16').
|
243 |
+
- batch_size (int): Batch size for transcription.
|
244 |
+
- SOURCE_LANGUAGE (str): Source language for transcription.
|
245 |
+
|
246 |
+
Returns:
|
247 |
+
- Tuple containing:
|
248 |
+
- audio: Loaded audio file.
|
249 |
+
- result: Transcription result as a dictionary.
|
250 |
+
"""
|
251 |
+
|
252 |
+
if asr_model == "OpenAI_API_Whisper":
|
253 |
+
if literalize_numbers:
|
254 |
+
logger.info(
|
255 |
+
"OpenAI's API Whisper does not support "
|
256 |
+
"the literalization of numbers."
|
257 |
+
)
|
258 |
+
return openai_api_whisper(audio_wav, SOURCE_LANGUAGE)
|
259 |
+
|
260 |
+
# https://github.com/openai/whisper/discussions/277
|
261 |
+
prompt = "以下是普通话的句子。" if SOURCE_LANGUAGE == "zh" else None
|
262 |
+
SOURCE_LANGUAGE = (
|
263 |
+
SOURCE_LANGUAGE if SOURCE_LANGUAGE != "zh-TW" else "zh"
|
264 |
+
)
|
265 |
+
asr_options = {
|
266 |
+
"initial_prompt": prompt,
|
267 |
+
"suppress_numerals": literalize_numbers
|
268 |
+
}
|
269 |
+
|
270 |
+
if asr_model not in ASR_MODEL_OPTIONS:
|
271 |
+
|
272 |
+
base_dir = WHISPER_MODELS_PATH
|
273 |
+
if not os.path.exists(base_dir):
|
274 |
+
os.makedirs(base_dir)
|
275 |
+
model_dir = os.path.join(base_dir, sanitize_file_name(asr_model))
|
276 |
+
|
277 |
+
if not os.path.exists(model_dir):
|
278 |
+
from ctranslate2.converters import TransformersConverter
|
279 |
+
|
280 |
+
quantization = "float32"
|
281 |
+
# Download new model
|
282 |
+
try:
|
283 |
+
converter = TransformersConverter(
|
284 |
+
asr_model,
|
285 |
+
low_cpu_mem_usage=True,
|
286 |
+
copy_files=[
|
287 |
+
"tokenizer_config.json", "preprocessor_config.json"
|
288 |
+
]
|
289 |
+
)
|
290 |
+
converter.convert(
|
291 |
+
model_dir,
|
292 |
+
quantization=quantization,
|
293 |
+
force=False
|
294 |
+
)
|
295 |
+
except Exception as error:
|
296 |
+
if "File tokenizer_config.json does not exist" in str(error):
|
297 |
+
converter._copy_files = [
|
298 |
+
"tokenizer.json", "preprocessor_config.json"
|
299 |
+
]
|
300 |
+
converter.convert(
|
301 |
+
model_dir,
|
302 |
+
quantization=quantization,
|
303 |
+
force=True
|
304 |
+
)
|
305 |
+
else:
|
306 |
+
raise error
|
307 |
+
|
308 |
+
asr_model = model_dir
|
309 |
+
logger.info(f"ASR Model: {str(model_dir)}")
|
310 |
+
|
311 |
+
audio = whisperx.load_audio(audio_wav)
|
312 |
+
|
313 |
+
result = load_and_transcribe_audio(
|
314 |
+
asr_model, audio, compute_type, SOURCE_LANGUAGE, asr_options, batch_size, segment_duration_limit
|
315 |
+
)
|
316 |
+
|
317 |
+
if result["language"] == "zh" and not prompt:
|
318 |
+
result["language"] = "zh-TW"
|
319 |
+
logger.info("Chinese - Traditional (zh-TW)")
|
320 |
+
|
321 |
+
|
322 |
+
return audio, result
|
323 |
+
|
324 |
+
|
325 |
+
def align_speech(audio, result):
|
326 |
+
"""
|
327 |
+
Aligns speech segments based on the provided audio and result metadata.
|
328 |
+
|
329 |
+
Parameters:
|
330 |
+
- audio (array): The audio data in a suitable format for alignment.
|
331 |
+
- result (dict): Metadata containing information about the segments
|
332 |
+
and language.
|
333 |
+
|
334 |
+
Returns:
|
335 |
+
- result (dict): Updated metadata after aligning the segments with
|
336 |
+
the audio. This includes character-level alignments if
|
337 |
+
'return_char_alignments' is set to True.
|
338 |
+
|
339 |
+
Notes:
|
340 |
+
- This function uses language-specific models to align speech segments.
|
341 |
+
- It performs language compatibility checks and selects the
|
342 |
+
appropriate alignment model.
|
343 |
+
- Cleans up memory by releasing resources after alignment.
|
344 |
+
"""
|
345 |
+
DAMHF.update(DAMT) # lang align
|
346 |
+
if (
|
347 |
+
not result["language"] in DAMHF.keys()
|
348 |
+
and not result["language"] in EXTRA_ALIGN.keys()
|
349 |
+
):
|
350 |
+
logger.warning(
|
351 |
+
"Automatic detection: Source language not compatible with align"
|
352 |
+
)
|
353 |
+
raise ValueError(
|
354 |
+
f"Detected language {result['language']} incompatible, "
|
355 |
+
"you can select the source language to avoid this error."
|
356 |
+
)
|
357 |
+
if (
|
358 |
+
result["language"] in EXTRA_ALIGN.keys()
|
359 |
+
and EXTRA_ALIGN[result["language"]] == ""
|
360 |
+
):
|
361 |
+
lang_name = (
|
362 |
+
INVERTED_LANGUAGES[result["language"]]
|
363 |
+
if result["language"] in INVERTED_LANGUAGES.keys()
|
364 |
+
else result["language"]
|
365 |
+
)
|
366 |
+
logger.warning(
|
367 |
+
"No compatible wav2vec2 model found "
|
368 |
+
f"for the language '{lang_name}', skipping alignment."
|
369 |
+
)
|
370 |
+
return result
|
371 |
+
|
372 |
+
random_sleep()
|
373 |
+
result = load_align_and_align_segments(result, audio, DAMHF)
|
374 |
+
|
375 |
+
return result
|
376 |
+
|
377 |
+
|
378 |
+
diarization_models = {
|
379 |
+
"pyannote_3.1": "pyannote/speaker-diarization-3.1",
|
380 |
+
"pyannote_2.1": "pyannote/speaker-diarization@2.1",
|
381 |
+
"disable": "",
|
382 |
+
}
|
383 |
+
|
384 |
+
|
385 |
+
def reencode_speakers(result):
|
386 |
+
|
387 |
+
if result["segments"][0]["speaker"] == "SPEAKER_00":
|
388 |
+
return result
|
389 |
+
|
390 |
+
speaker_mapping = {}
|
391 |
+
counter = 0
|
392 |
+
|
393 |
+
logger.debug("Reencode speakers")
|
394 |
+
|
395 |
+
for segment in result["segments"]:
|
396 |
+
old_speaker = segment["speaker"]
|
397 |
+
if old_speaker not in speaker_mapping:
|
398 |
+
speaker_mapping[old_speaker] = f"SPEAKER_{counter:02d}"
|
399 |
+
counter += 1
|
400 |
+
segment["speaker"] = speaker_mapping[old_speaker]
|
401 |
+
|
402 |
+
return result
|
403 |
+
|
404 |
+
|
405 |
+
def diarize_speech(
|
406 |
+
audio_wav,
|
407 |
+
result,
|
408 |
+
min_speakers,
|
409 |
+
max_speakers,
|
410 |
+
YOUR_HF_TOKEN,
|
411 |
+
model_name="pyannote/speaker-diarization@2.1",
|
412 |
+
):
|
413 |
+
"""
|
414 |
+
Performs speaker diarization on speech segments.
|
415 |
+
|
416 |
+
Parameters:
|
417 |
+
- audio_wav (array): Audio data in WAV format to perform speaker
|
418 |
+
diarization.
|
419 |
+
- result (dict): Metadata containing information about speech segments
|
420 |
+
and alignments.
|
421 |
+
- min_speakers (int): Minimum number of speakers expected in the audio.
|
422 |
+
- max_speakers (int): Maximum number of speakers expected in the audio.
|
423 |
+
- YOUR_HF_TOKEN (str): Your Hugging Face API token for model
|
424 |
+
authentication.
|
425 |
+
- model_name (str): Name of the speaker diarization model to be used
|
426 |
+
(default: "pyannote/speaker-diarization@2.1").
|
427 |
+
|
428 |
+
Returns:
|
429 |
+
- result_diarize (dict): Updated metadata after assigning speaker
|
430 |
+
labels to segments.
|
431 |
+
|
432 |
+
Notes:
|
433 |
+
- This function utilizes a speaker diarization model to label speaker
|
434 |
+
segments in the audio.
|
435 |
+
- It assigns speakers to word-level segments based on diarization results.
|
436 |
+
- Cleans up memory by releasing resources after diarization.
|
437 |
+
- If only one speaker is specified, each segment is automatically assigned
|
438 |
+
as the first speaker, eliminating the need for diarization inference.
|
439 |
+
"""
|
440 |
+
|
441 |
+
if max(min_speakers, max_speakers) > 1 and model_name:
|
442 |
+
try:
|
443 |
+
|
444 |
+
diarize_model = whisperx.DiarizationPipeline(
|
445 |
+
model_name=model_name,
|
446 |
+
use_auth_token=YOUR_HF_TOKEN,
|
447 |
+
device=os.environ.get("SONITR_DEVICE"),
|
448 |
+
)
|
449 |
+
|
450 |
+
except Exception as error:
|
451 |
+
error_str = str(error)
|
452 |
+
gc.collect()
|
453 |
+
torch.cuda.empty_cache() # noqa
|
454 |
+
if "'NoneType' object has no attribute 'to'" in error_str:
|
455 |
+
if model_name == diarization_models["pyannote_2.1"]:
|
456 |
+
raise ValueError(
|
457 |
+
"Accept the license agreement for using Pyannote 2.1."
|
458 |
+
" You need to have an account on Hugging Face and "
|
459 |
+
"accept the license to use the models: "
|
460 |
+
"https://huggingface.co/pyannote/speaker-diarization "
|
461 |
+
"and https://huggingface.co/pyannote/segmentation "
|
462 |
+
"Get your KEY TOKEN here: "
|
463 |
+
"https://hf.co/settings/tokens "
|
464 |
+
)
|
465 |
+
elif model_name == diarization_models["pyannote_3.1"]:
|
466 |
+
raise ValueError(
|
467 |
+
"New Licence Pyannote 3.1: You need to have an account"
|
468 |
+
" on Hugging Face and accept the license to use the "
|
469 |
+
"models: https://huggingface.co/pyannote/speaker-diarization-3.1 " # noqa
|
470 |
+
"and https://huggingface.co/pyannote/segmentation-3.0 "
|
471 |
+
)
|
472 |
+
else:
|
473 |
+
raise error
|
474 |
+
|
475 |
+
random_sleep()
|
476 |
+
diarize_segments = diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers)
|
477 |
+
|
478 |
+
result_diarize = whisperx.assign_word_speakers(
|
479 |
+
diarize_segments, result
|
480 |
+
)
|
481 |
+
|
482 |
+
for segment in result_diarize["segments"]:
|
483 |
+
if "speaker" not in segment:
|
484 |
+
segment["speaker"] = "SPEAKER_00"
|
485 |
+
logger.warning(
|
486 |
+
f"No speaker detected in {segment['start']}. First TTS "
|
487 |
+
f"will be used for the segment text: {segment['text']} "
|
488 |
+
)
|
489 |
+
|
490 |
+
del diarize_model
|
491 |
+
gc.collect()
|
492 |
+
torch.cuda.empty_cache() # noqa
|
493 |
+
else:
|
494 |
+
result_diarize = result
|
495 |
+
result_diarize["segments"] = [
|
496 |
+
{**item, "speaker": "SPEAKER_00"}
|
497 |
+
for item in result_diarize["segments"]
|
498 |
+
]
|
499 |
+
return reencode_speakers(result_diarize)
|
soni_translate/text_multiformat_processor.py
ADDED
@@ -0,0 +1,987 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .logging_setup import logger
|
2 |
+
from whisperx.utils import get_writer
|
3 |
+
from .utils import remove_files, run_command, remove_directory_contents
|
4 |
+
from typing import List
|
5 |
+
import srt
|
6 |
+
import re
|
7 |
+
import os
|
8 |
+
import copy
|
9 |
+
import string
|
10 |
+
import soundfile as sf
|
11 |
+
from PIL import Image, ImageOps, ImageDraw, ImageFont
|
12 |
+
|
13 |
+
punctuation_list = list(
|
14 |
+
string.punctuation + "¡¿«»„”“”‚‘’「」『』《》()【】〈〉〔〕〖〗〘〙〚〛⸤⸥⸨⸩"
|
15 |
+
)
|
16 |
+
symbol_list = punctuation_list + ["", "..", "..."]
|
17 |
+
|
18 |
+
|
19 |
+
def extract_from_srt(file_path):
|
20 |
+
with open(file_path, "r", encoding="utf-8") as file:
|
21 |
+
srt_content = file.read()
|
22 |
+
|
23 |
+
subtitle_generator = srt.parse(srt_content)
|
24 |
+
srt_content_list = list(subtitle_generator)
|
25 |
+
|
26 |
+
return srt_content_list
|
27 |
+
|
28 |
+
|
29 |
+
def clean_text(text):
|
30 |
+
|
31 |
+
# Remove content within square brackets
|
32 |
+
text = re.sub(r'\[.*?\]', '', text)
|
33 |
+
# Add pattern to remove content within <comment> tags
|
34 |
+
text = re.sub(r'<comment>.*?</comment>', '', text)
|
35 |
+
# Remove HTML tags
|
36 |
+
text = re.sub(r'<.*?>', '', text)
|
37 |
+
# Remove "♫" and "♪" content
|
38 |
+
text = re.sub(r'♫.*?♫', '', text)
|
39 |
+
text = re.sub(r'♪.*?♪', '', text)
|
40 |
+
# Replace newline characters with an empty string
|
41 |
+
text = text.replace("\n", ". ")
|
42 |
+
# Remove double quotation marks
|
43 |
+
text = text.replace('"', '')
|
44 |
+
# Collapse multiple spaces and replace with a single space
|
45 |
+
text = re.sub(r"\s+", " ", text)
|
46 |
+
# Normalize spaces around periods
|
47 |
+
text = re.sub(r"[\s\.]+(?=\s)", ". ", text)
|
48 |
+
# Check if there are ♫ or ♪ symbols present
|
49 |
+
if '♫' in text or '♪' in text:
|
50 |
+
return ""
|
51 |
+
|
52 |
+
text = text.strip()
|
53 |
+
|
54 |
+
# Valid text
|
55 |
+
return text if text not in symbol_list else ""
|
56 |
+
|
57 |
+
|
58 |
+
def srt_file_to_segments(file_path, speaker=False):
|
59 |
+
try:
|
60 |
+
srt_content_list = extract_from_srt(file_path)
|
61 |
+
except Exception as error:
|
62 |
+
logger.error(str(error))
|
63 |
+
fixed_file = "fixed_sub.srt"
|
64 |
+
remove_files(fixed_file)
|
65 |
+
fix_sub = f'ffmpeg -i "{file_path}" "{fixed_file}" -y'
|
66 |
+
run_command(fix_sub)
|
67 |
+
srt_content_list = extract_from_srt(fixed_file)
|
68 |
+
|
69 |
+
segments = []
|
70 |
+
for segment in srt_content_list:
|
71 |
+
|
72 |
+
text = clean_text(str(segment.content))
|
73 |
+
|
74 |
+
if text:
|
75 |
+
segments.append(
|
76 |
+
{
|
77 |
+
"text": text,
|
78 |
+
"start": float(segment.start.total_seconds()),
|
79 |
+
"end": float(segment.end.total_seconds()),
|
80 |
+
}
|
81 |
+
)
|
82 |
+
|
83 |
+
if not segments:
|
84 |
+
raise Exception("No data found in srt subtitle file")
|
85 |
+
|
86 |
+
if speaker:
|
87 |
+
segments = [{**seg, "speaker": "SPEAKER_00"} for seg in segments]
|
88 |
+
|
89 |
+
return {"segments": segments}
|
90 |
+
|
91 |
+
|
92 |
+
# documents
|
93 |
+
|
94 |
+
|
95 |
+
def dehyphenate(lines: List[str], line_no: int) -> List[str]:
|
96 |
+
next_line = lines[line_no + 1]
|
97 |
+
word_suffix = next_line.split(" ")[0]
|
98 |
+
|
99 |
+
lines[line_no] = lines[line_no][:-1] + word_suffix
|
100 |
+
lines[line_no + 1] = lines[line_no + 1][len(word_suffix):]
|
101 |
+
return lines
|
102 |
+
|
103 |
+
|
104 |
+
def remove_hyphens(text: str) -> str:
|
105 |
+
"""
|
106 |
+
|
107 |
+
This fails for:
|
108 |
+
* Natural dashes: well-known, self-replication, use-cases, non-semantic,
|
109 |
+
Post-processing, Window-wise, viewpoint-dependent
|
110 |
+
* Trailing math operands: 2 - 4
|
111 |
+
* Names: Lopez-Ferreras, VGG-19, CIFAR-100
|
112 |
+
"""
|
113 |
+
lines = [line.rstrip() for line in text.split("\n")]
|
114 |
+
|
115 |
+
# Find dashes
|
116 |
+
line_numbers = []
|
117 |
+
for line_no, line in enumerate(lines[:-1]):
|
118 |
+
if line.endswith("-"):
|
119 |
+
line_numbers.append(line_no)
|
120 |
+
|
121 |
+
# Replace
|
122 |
+
for line_no in line_numbers:
|
123 |
+
lines = dehyphenate(lines, line_no)
|
124 |
+
|
125 |
+
return "\n".join(lines)
|
126 |
+
|
127 |
+
|
128 |
+
def pdf_to_txt(pdf_file, start_page, end_page):
|
129 |
+
from pypdf import PdfReader
|
130 |
+
|
131 |
+
with open(pdf_file, "rb") as file:
|
132 |
+
reader = PdfReader(file)
|
133 |
+
logger.debug(f"Total pages: {reader.get_num_pages()}")
|
134 |
+
text = ""
|
135 |
+
|
136 |
+
start_page_idx = max((start_page-1), 0)
|
137 |
+
end_page_inx = min((end_page), (reader.get_num_pages()))
|
138 |
+
document_pages = reader.pages[start_page_idx:end_page_inx]
|
139 |
+
logger.info(
|
140 |
+
f"Selected pages from {start_page_idx} to {end_page_inx}: "
|
141 |
+
f"{len(document_pages)}"
|
142 |
+
)
|
143 |
+
|
144 |
+
for page in document_pages:
|
145 |
+
text += remove_hyphens(page.extract_text())
|
146 |
+
return text
|
147 |
+
|
148 |
+
|
149 |
+
def docx_to_txt(docx_file):
|
150 |
+
# https://github.com/AlJohri/docx2pdf update
|
151 |
+
from docx import Document
|
152 |
+
|
153 |
+
doc = Document(docx_file)
|
154 |
+
text = ""
|
155 |
+
for paragraph in doc.paragraphs:
|
156 |
+
text += paragraph.text + "\n"
|
157 |
+
return text
|
158 |
+
|
159 |
+
|
160 |
+
def replace_multiple_elements(text, replacements):
|
161 |
+
pattern = re.compile("|".join(map(re.escape, replacements.keys())))
|
162 |
+
replaced_text = pattern.sub(
|
163 |
+
lambda match: replacements[match.group(0)], text
|
164 |
+
)
|
165 |
+
|
166 |
+
# Remove multiple spaces
|
167 |
+
replaced_text = re.sub(r"\s+", " ", replaced_text)
|
168 |
+
|
169 |
+
return replaced_text
|
170 |
+
|
171 |
+
|
172 |
+
def document_preprocessor(file_path, is_string, start_page, end_page):
|
173 |
+
if not is_string:
|
174 |
+
file_ext = os.path.splitext(file_path)[1].lower()
|
175 |
+
|
176 |
+
if is_string:
|
177 |
+
text = file_path
|
178 |
+
elif file_ext == ".pdf":
|
179 |
+
text = pdf_to_txt(file_path, start_page, end_page)
|
180 |
+
elif file_ext == ".docx":
|
181 |
+
text = docx_to_txt(file_path)
|
182 |
+
elif file_ext == ".txt":
|
183 |
+
with open(
|
184 |
+
file_path, "r", encoding='utf-8', errors='replace'
|
185 |
+
) as file:
|
186 |
+
text = file.read()
|
187 |
+
else:
|
188 |
+
raise Exception("Unsupported file format")
|
189 |
+
|
190 |
+
# Add space to break segments more easily later
|
191 |
+
replacements = {
|
192 |
+
"、": "、 ",
|
193 |
+
"。": "。 ",
|
194 |
+
# "\n": " ",
|
195 |
+
}
|
196 |
+
text = replace_multiple_elements(text, replacements)
|
197 |
+
|
198 |
+
# Save text to a .txt file
|
199 |
+
# file_name = os.path.splitext(os.path.basename(file_path))[0]
|
200 |
+
txt_file_path = "./text_preprocessor.txt"
|
201 |
+
|
202 |
+
with open(
|
203 |
+
txt_file_path, "w", encoding='utf-8', errors='replace'
|
204 |
+
) as txt_file:
|
205 |
+
txt_file.write(text)
|
206 |
+
|
207 |
+
return txt_file_path, text
|
208 |
+
|
209 |
+
|
210 |
+
def split_text_into_chunks(text, chunk_size):
|
211 |
+
words = re.findall(r"\b\w+\b", text)
|
212 |
+
chunks = []
|
213 |
+
current_chunk = ""
|
214 |
+
for word in words:
|
215 |
+
if (
|
216 |
+
len(current_chunk) + len(word) + 1 <= chunk_size
|
217 |
+
): # Adding 1 for the space between words
|
218 |
+
if current_chunk:
|
219 |
+
current_chunk += " "
|
220 |
+
current_chunk += word
|
221 |
+
else:
|
222 |
+
chunks.append(current_chunk)
|
223 |
+
current_chunk = word
|
224 |
+
if current_chunk:
|
225 |
+
chunks.append(current_chunk)
|
226 |
+
return chunks
|
227 |
+
|
228 |
+
|
229 |
+
def determine_chunk_size(file_name):
|
230 |
+
patterns = {
|
231 |
+
re.compile(r".*-(Male|Female)$"): 1024, # by character
|
232 |
+
re.compile(r".* BARK$"): 100, # t 64 256
|
233 |
+
re.compile(r".* VITS$"): 500,
|
234 |
+
re.compile(
|
235 |
+
r".+\.(wav|mp3|ogg|m4a)$"
|
236 |
+
): 150, # t 250 400 api automatic split
|
237 |
+
re.compile(r".* VITS-onnx$"): 250, # automatic sentence split
|
238 |
+
re.compile(r".* OpenAI-TTS$"): 1024 # max charaters 4096
|
239 |
+
}
|
240 |
+
|
241 |
+
for pattern, chunk_size in patterns.items():
|
242 |
+
if pattern.match(file_name):
|
243 |
+
return chunk_size
|
244 |
+
|
245 |
+
# Default chunk size if the file doesn't match any pattern; max 1800
|
246 |
+
return 100
|
247 |
+
|
248 |
+
|
249 |
+
def plain_text_to_segments(result_text=None, chunk_size=None):
|
250 |
+
if not chunk_size:
|
251 |
+
chunk_size = 100
|
252 |
+
text_chunks = split_text_into_chunks(result_text, chunk_size)
|
253 |
+
|
254 |
+
segments_chunks = []
|
255 |
+
for num, chunk in enumerate(text_chunks):
|
256 |
+
chunk_dict = {
|
257 |
+
"text": chunk,
|
258 |
+
"start": (1.0 + num),
|
259 |
+
"end": (2.0 + num),
|
260 |
+
"speaker": "SPEAKER_00",
|
261 |
+
}
|
262 |
+
segments_chunks.append(chunk_dict)
|
263 |
+
|
264 |
+
result_diarize = {"segments": segments_chunks}
|
265 |
+
|
266 |
+
return result_diarize
|
267 |
+
|
268 |
+
|
269 |
+
def segments_to_plain_text(result_diarize):
|
270 |
+
complete_text = ""
|
271 |
+
for seg in result_diarize["segments"]:
|
272 |
+
complete_text += seg["text"] + " " # issue
|
273 |
+
|
274 |
+
# Save text to a .txt file
|
275 |
+
# file_name = os.path.splitext(os.path.basename(file_path))[0]
|
276 |
+
txt_file_path = "./text_translation.txt"
|
277 |
+
|
278 |
+
with open(
|
279 |
+
txt_file_path, "w", encoding='utf-8', errors='replace'
|
280 |
+
) as txt_file:
|
281 |
+
txt_file.write(complete_text)
|
282 |
+
|
283 |
+
return txt_file_path, complete_text
|
284 |
+
|
285 |
+
|
286 |
+
# doc to video
|
287 |
+
|
288 |
+
COLORS = {
|
289 |
+
"black": (0, 0, 0),
|
290 |
+
"white": (255, 255, 255),
|
291 |
+
"red": (255, 0, 0),
|
292 |
+
"green": (0, 255, 0),
|
293 |
+
"blue": (0, 0, 255),
|
294 |
+
"yellow": (255, 255, 0),
|
295 |
+
"light_gray": (200, 200, 200),
|
296 |
+
"light_blue": (173, 216, 230),
|
297 |
+
"light_green": (144, 238, 144),
|
298 |
+
"light_yellow": (255, 255, 224),
|
299 |
+
"light_pink": (255, 182, 193),
|
300 |
+
"lavender": (230, 230, 250),
|
301 |
+
"peach": (255, 218, 185),
|
302 |
+
"light_cyan": (224, 255, 255),
|
303 |
+
"light_salmon": (255, 160, 122),
|
304 |
+
"light_green_yellow": (173, 255, 47),
|
305 |
+
}
|
306 |
+
|
307 |
+
BORDER_COLORS = ["dynamic"] + list(COLORS.keys())
|
308 |
+
|
309 |
+
|
310 |
+
def calculate_average_color(img):
|
311 |
+
# Resize the image to a small size for faster processing
|
312 |
+
img_small = img.resize((50, 50))
|
313 |
+
# Calculate the average color
|
314 |
+
average_color = img_small.convert("RGB").resize((1, 1)).getpixel((0, 0))
|
315 |
+
return average_color
|
316 |
+
|
317 |
+
|
318 |
+
def add_border_to_image(
|
319 |
+
image_path,
|
320 |
+
target_width,
|
321 |
+
target_height,
|
322 |
+
border_color=None
|
323 |
+
):
|
324 |
+
|
325 |
+
img = Image.open(image_path)
|
326 |
+
|
327 |
+
# Calculate the width and height for the new image with borders
|
328 |
+
original_width, original_height = img.size
|
329 |
+
original_aspect_ratio = original_width / original_height
|
330 |
+
target_aspect_ratio = target_width / target_height
|
331 |
+
|
332 |
+
# Resize the image to fit the target resolution retaining aspect ratio
|
333 |
+
if original_aspect_ratio > target_aspect_ratio:
|
334 |
+
# Image is wider, calculate new height
|
335 |
+
new_height = int(target_width / original_aspect_ratio)
|
336 |
+
resized_img = img.resize((target_width, new_height))
|
337 |
+
else:
|
338 |
+
# Image is taller, calculate new width
|
339 |
+
new_width = int(target_height * original_aspect_ratio)
|
340 |
+
resized_img = img.resize((new_width, target_height))
|
341 |
+
|
342 |
+
# Calculate padding for borders
|
343 |
+
padding = (0, 0, 0, 0)
|
344 |
+
if resized_img.size[0] != target_width or resized_img.size[1] != target_height:
|
345 |
+
if original_aspect_ratio > target_aspect_ratio:
|
346 |
+
# Add borders vertically
|
347 |
+
padding = (0, (target_height - resized_img.size[1]) // 2, 0, (target_height - resized_img.size[1]) // 2)
|
348 |
+
else:
|
349 |
+
# Add borders horizontally
|
350 |
+
padding = ((target_width - resized_img.size[0]) // 2, 0, (target_width - resized_img.size[0]) // 2, 0)
|
351 |
+
|
352 |
+
# Add borders with specified color
|
353 |
+
if not border_color or border_color == "dynamic":
|
354 |
+
border_color = calculate_average_color(resized_img)
|
355 |
+
else:
|
356 |
+
border_color = COLORS.get(border_color, (0, 0, 0))
|
357 |
+
|
358 |
+
bordered_img = ImageOps.expand(resized_img, padding, fill=border_color)
|
359 |
+
|
360 |
+
bordered_img.save(image_path)
|
361 |
+
|
362 |
+
return image_path
|
363 |
+
|
364 |
+
|
365 |
+
def resize_and_position_subimage(
|
366 |
+
subimage,
|
367 |
+
max_width,
|
368 |
+
max_height,
|
369 |
+
subimage_position,
|
370 |
+
main_width,
|
371 |
+
main_height
|
372 |
+
):
|
373 |
+
subimage_width, subimage_height = subimage.size
|
374 |
+
|
375 |
+
# Resize subimage if it exceeds maximum dimensions
|
376 |
+
if subimage_width > max_width or subimage_height > max_height:
|
377 |
+
# Calculate scaling factor
|
378 |
+
width_scale = max_width / subimage_width
|
379 |
+
height_scale = max_height / subimage_height
|
380 |
+
scale = min(width_scale, height_scale)
|
381 |
+
|
382 |
+
# Resize subimage
|
383 |
+
subimage = subimage.resize(
|
384 |
+
(int(subimage_width * scale), int(subimage_height * scale))
|
385 |
+
)
|
386 |
+
|
387 |
+
# Calculate position to place the subimage
|
388 |
+
if subimage_position == "top-left":
|
389 |
+
subimage_x = 0
|
390 |
+
subimage_y = 0
|
391 |
+
elif subimage_position == "top-right":
|
392 |
+
subimage_x = main_width - subimage.width
|
393 |
+
subimage_y = 0
|
394 |
+
elif subimage_position == "bottom-left":
|
395 |
+
subimage_x = 0
|
396 |
+
subimage_y = main_height - subimage.height
|
397 |
+
elif subimage_position == "bottom-right":
|
398 |
+
subimage_x = main_width - subimage.width
|
399 |
+
subimage_y = main_height - subimage.height
|
400 |
+
else:
|
401 |
+
raise ValueError(
|
402 |
+
"Invalid subimage_position. Choose from 'top-left', 'top-right',"
|
403 |
+
" 'bottom-left', or 'bottom-right'."
|
404 |
+
)
|
405 |
+
|
406 |
+
return subimage, subimage_x, subimage_y
|
407 |
+
|
408 |
+
|
409 |
+
def create_image_with_text_and_subimages(
|
410 |
+
text,
|
411 |
+
subimages,
|
412 |
+
width,
|
413 |
+
height,
|
414 |
+
text_color,
|
415 |
+
background_color,
|
416 |
+
output_file
|
417 |
+
):
|
418 |
+
# Create an image with the specified resolution and background color
|
419 |
+
image = Image.new('RGB', (width, height), color=background_color)
|
420 |
+
|
421 |
+
# Initialize ImageDraw object
|
422 |
+
draw = ImageDraw.Draw(image)
|
423 |
+
|
424 |
+
# Load a font
|
425 |
+
font = ImageFont.load_default() # You can specify your font file here
|
426 |
+
|
427 |
+
# Calculate text size and position
|
428 |
+
text_bbox = draw.textbbox((0, 0), text, font=font)
|
429 |
+
text_width = text_bbox[2] - text_bbox[0]
|
430 |
+
text_height = text_bbox[3] - text_bbox[1]
|
431 |
+
text_x = (width - text_width) / 2
|
432 |
+
text_y = (height - text_height) / 2
|
433 |
+
|
434 |
+
# Draw text on the image
|
435 |
+
draw.text((text_x, text_y), text, fill=text_color, font=font)
|
436 |
+
|
437 |
+
# Paste subimages onto the main image
|
438 |
+
for subimage_path, subimage_position in subimages:
|
439 |
+
# Open the subimage
|
440 |
+
subimage = Image.open(subimage_path)
|
441 |
+
|
442 |
+
# Convert subimage to RGBA mode if it doesn't have an alpha channel
|
443 |
+
if subimage.mode != 'RGBA':
|
444 |
+
subimage = subimage.convert('RGBA')
|
445 |
+
|
446 |
+
# Resize and position the subimage
|
447 |
+
subimage, subimage_x, subimage_y = resize_and_position_subimage(
|
448 |
+
subimage, width / 4, height / 4, subimage_position, width, height
|
449 |
+
)
|
450 |
+
|
451 |
+
# Paste the subimage onto the main image
|
452 |
+
image.paste(subimage, (int(subimage_x), int(subimage_y)), subimage)
|
453 |
+
|
454 |
+
image.save(output_file)
|
455 |
+
|
456 |
+
return output_file
|
457 |
+
|
458 |
+
|
459 |
+
def doc_to_txtximg_pages(
|
460 |
+
document,
|
461 |
+
width,
|
462 |
+
height,
|
463 |
+
start_page,
|
464 |
+
end_page,
|
465 |
+
bcolor
|
466 |
+
):
|
467 |
+
from pypdf import PdfReader
|
468 |
+
|
469 |
+
images_folder = "pdf_images/"
|
470 |
+
os.makedirs(images_folder, exist_ok=True)
|
471 |
+
remove_directory_contents(images_folder)
|
472 |
+
|
473 |
+
# First image
|
474 |
+
text_image = os.path.basename(document)[:-4]
|
475 |
+
subimages = [("./assets/logo.jpeg", "top-left")]
|
476 |
+
text_color = (255, 255, 255) if bcolor == "black" else (0, 0, 0) # w|b
|
477 |
+
background_color = COLORS.get(bcolor, (255, 255, 255)) # dynamic white
|
478 |
+
first_image = "pdf_images/0000_00_aaa.png"
|
479 |
+
|
480 |
+
create_image_with_text_and_subimages(
|
481 |
+
text_image,
|
482 |
+
subimages,
|
483 |
+
width,
|
484 |
+
height,
|
485 |
+
text_color,
|
486 |
+
background_color,
|
487 |
+
first_image
|
488 |
+
)
|
489 |
+
|
490 |
+
reader = PdfReader(document)
|
491 |
+
logger.debug(f"Total pages: {reader.get_num_pages()}")
|
492 |
+
|
493 |
+
start_page_idx = max((start_page-1), 0)
|
494 |
+
end_page_inx = min((end_page), (reader.get_num_pages()))
|
495 |
+
document_pages = reader.pages[start_page_idx:end_page_inx]
|
496 |
+
|
497 |
+
logger.info(
|
498 |
+
f"Selected pages from {start_page_idx} to {end_page_inx}: "
|
499 |
+
f"{len(document_pages)}"
|
500 |
+
)
|
501 |
+
|
502 |
+
data_doc = {}
|
503 |
+
for i, page in enumerate(document_pages):
|
504 |
+
|
505 |
+
count = 0
|
506 |
+
images = []
|
507 |
+
for image_file_object in page.images:
|
508 |
+
img_name = f"{images_folder}{i:04d}_{count:02d}_{image_file_object.name}"
|
509 |
+
images.append(img_name)
|
510 |
+
with open(img_name, "wb") as fp:
|
511 |
+
fp.write(image_file_object.data)
|
512 |
+
count += 1
|
513 |
+
img_name = add_border_to_image(img_name, width, height, bcolor)
|
514 |
+
|
515 |
+
data_doc[i] = {
|
516 |
+
"text": remove_hyphens(page.extract_text()),
|
517 |
+
"images": images
|
518 |
+
}
|
519 |
+
|
520 |
+
return data_doc
|
521 |
+
|
522 |
+
|
523 |
+
def page_data_to_segments(result_text=None, chunk_size=None):
|
524 |
+
|
525 |
+
if not chunk_size:
|
526 |
+
chunk_size = 100
|
527 |
+
|
528 |
+
segments_chunks = []
|
529 |
+
time_global = 0
|
530 |
+
for page, result_data in result_text.items():
|
531 |
+
# result_image = result_data["images"]
|
532 |
+
result_text = result_data["text"]
|
533 |
+
text_chunks = split_text_into_chunks(result_text, chunk_size)
|
534 |
+
if not text_chunks:
|
535 |
+
text_chunks = [" "]
|
536 |
+
|
537 |
+
for chunk in text_chunks:
|
538 |
+
chunk_dict = {
|
539 |
+
"text": chunk,
|
540 |
+
"start": (1.0 + time_global),
|
541 |
+
"end": (2.0 + time_global),
|
542 |
+
"speaker": "SPEAKER_00",
|
543 |
+
"page": page,
|
544 |
+
}
|
545 |
+
segments_chunks.append(chunk_dict)
|
546 |
+
time_global += 1
|
547 |
+
|
548 |
+
result_diarize = {"segments": segments_chunks}
|
549 |
+
|
550 |
+
return result_diarize
|
551 |
+
|
552 |
+
|
553 |
+
def update_page_data(result_diarize, doc_data):
|
554 |
+
complete_text = ""
|
555 |
+
current_page = result_diarize["segments"][0]["page"]
|
556 |
+
text_page = ""
|
557 |
+
|
558 |
+
for seg in result_diarize["segments"]:
|
559 |
+
text = seg["text"] + " " # issue
|
560 |
+
complete_text += text
|
561 |
+
|
562 |
+
page = seg["page"]
|
563 |
+
|
564 |
+
if page == current_page:
|
565 |
+
text_page += text
|
566 |
+
else:
|
567 |
+
doc_data[current_page]["text"] = text_page
|
568 |
+
|
569 |
+
# Next
|
570 |
+
text_page = text
|
571 |
+
current_page = page
|
572 |
+
|
573 |
+
if doc_data[current_page]["text"] != text_page:
|
574 |
+
doc_data[current_page]["text"] = text_page
|
575 |
+
|
576 |
+
return doc_data
|
577 |
+
|
578 |
+
|
579 |
+
def fix_timestamps_docs(result_diarize, audio_files):
|
580 |
+
current_start = 0.0
|
581 |
+
|
582 |
+
for seg, audio in zip(result_diarize["segments"], audio_files):
|
583 |
+
duration = round(sf.info(audio).duration, 2)
|
584 |
+
|
585 |
+
seg["start"] = current_start
|
586 |
+
current_start += duration
|
587 |
+
seg["end"] = current_start
|
588 |
+
|
589 |
+
return result_diarize
|
590 |
+
|
591 |
+
|
592 |
+
def create_video_from_images(
|
593 |
+
doc_data,
|
594 |
+
result_diarize
|
595 |
+
):
|
596 |
+
|
597 |
+
# First image path
|
598 |
+
first_image = "pdf_images/0000_00_aaa.png"
|
599 |
+
|
600 |
+
# Time segments and images
|
601 |
+
max_pages_idx = len(doc_data) - 1
|
602 |
+
current_page = result_diarize["segments"][0]["page"]
|
603 |
+
duration_page = 0.0
|
604 |
+
last_image = None
|
605 |
+
|
606 |
+
for seg in result_diarize["segments"]:
|
607 |
+
start = seg["start"]
|
608 |
+
end = seg["end"]
|
609 |
+
duration_seg = end - start
|
610 |
+
|
611 |
+
page = seg["page"]
|
612 |
+
|
613 |
+
if page == current_page:
|
614 |
+
duration_page += duration_seg
|
615 |
+
else:
|
616 |
+
|
617 |
+
images = doc_data[current_page]["images"]
|
618 |
+
|
619 |
+
if first_image:
|
620 |
+
images = [first_image] + images
|
621 |
+
first_image = None
|
622 |
+
if not doc_data[min(max_pages_idx, (current_page+1))]["text"].strip():
|
623 |
+
images = images + doc_data[min(max_pages_idx, (current_page+1))]["images"]
|
624 |
+
if not images and last_image:
|
625 |
+
images = [last_image]
|
626 |
+
|
627 |
+
# Calculate images duration
|
628 |
+
time_duration_per_image = round((duration_page / len(images)), 2)
|
629 |
+
doc_data[current_page]["time_per_image"] = time_duration_per_image
|
630 |
+
|
631 |
+
# Next values
|
632 |
+
doc_data[current_page]["images"] = images
|
633 |
+
last_image = images[-1]
|
634 |
+
duration_page = duration_seg
|
635 |
+
current_page = page
|
636 |
+
|
637 |
+
if "time_per_image" not in doc_data[current_page].keys():
|
638 |
+
images = doc_data[current_page]["images"]
|
639 |
+
if first_image:
|
640 |
+
images = [first_image] + images
|
641 |
+
if not images:
|
642 |
+
images = [last_image]
|
643 |
+
time_duration_per_image = round((duration_page / len(images)), 2)
|
644 |
+
doc_data[current_page]["time_per_image"] = time_duration_per_image
|
645 |
+
|
646 |
+
# Timestamped image video.
|
647 |
+
with open("list.txt", "w") as file:
|
648 |
+
|
649 |
+
for i, page in enumerate(doc_data.values()):
|
650 |
+
|
651 |
+
duration = page["time_per_image"]
|
652 |
+
for img in page["images"]:
|
653 |
+
if i == len(doc_data) - 1 and img == page["images"][-1]: # Check if it's the last item
|
654 |
+
file.write(f"file {img}\n")
|
655 |
+
file.write(f"outpoint {duration}")
|
656 |
+
else:
|
657 |
+
file.write(f"file {img}\n")
|
658 |
+
file.write(f"outpoint {duration}\n")
|
659 |
+
|
660 |
+
out_video = "video_from_images.mp4"
|
661 |
+
remove_files(out_video)
|
662 |
+
|
663 |
+
cm = f"ffmpeg -y -f concat -i list.txt -c:v libx264 -preset veryfast -crf 18 -pix_fmt yuv420p {out_video}"
|
664 |
+
cm_alt = f"ffmpeg -f concat -i list.txt -c:v libx264 -r 30 -pix_fmt yuv420p -y {out_video}"
|
665 |
+
try:
|
666 |
+
run_command(cm)
|
667 |
+
except Exception as error:
|
668 |
+
logger.error(str(error))
|
669 |
+
remove_files(out_video)
|
670 |
+
run_command(cm_alt)
|
671 |
+
|
672 |
+
return out_video
|
673 |
+
|
674 |
+
|
675 |
+
def merge_video_and_audio(video_doc, final_wav_file):
|
676 |
+
|
677 |
+
fixed_audio = "fixed_audio.mp3"
|
678 |
+
remove_files(fixed_audio)
|
679 |
+
cm = f"ffmpeg -i {final_wav_file} -c:a libmp3lame {fixed_audio}"
|
680 |
+
run_command(cm)
|
681 |
+
|
682 |
+
vid_out = "video_book.mp4"
|
683 |
+
remove_files(vid_out)
|
684 |
+
cm = f"ffmpeg -i {video_doc} -i {fixed_audio} -c:v copy -c:a copy -map 0:v -map 1:a -shortest {vid_out}"
|
685 |
+
run_command(cm)
|
686 |
+
|
687 |
+
return vid_out
|
688 |
+
|
689 |
+
|
690 |
+
# subtitles
|
691 |
+
|
692 |
+
|
693 |
+
def get_subtitle(
|
694 |
+
language,
|
695 |
+
segments_data,
|
696 |
+
extension,
|
697 |
+
filename=None,
|
698 |
+
highlight_words=False,
|
699 |
+
):
|
700 |
+
if not filename:
|
701 |
+
filename = "task_subtitle"
|
702 |
+
|
703 |
+
is_ass_extension = False
|
704 |
+
if extension == "ass":
|
705 |
+
is_ass_extension = True
|
706 |
+
extension = "srt"
|
707 |
+
|
708 |
+
sub_file = filename + "." + extension
|
709 |
+
support_name = filename + ".mp3"
|
710 |
+
remove_files(sub_file)
|
711 |
+
|
712 |
+
writer = get_writer(extension, output_dir=".")
|
713 |
+
word_options = {
|
714 |
+
"highlight_words": highlight_words,
|
715 |
+
"max_line_count": None,
|
716 |
+
"max_line_width": None,
|
717 |
+
}
|
718 |
+
|
719 |
+
# Get data subs
|
720 |
+
subtitle_data = copy.deepcopy(segments_data)
|
721 |
+
subtitle_data["language"] = (
|
722 |
+
"ja" if language in ["ja", "zh", "zh-TW"] else language
|
723 |
+
)
|
724 |
+
|
725 |
+
# Clean
|
726 |
+
if not highlight_words:
|
727 |
+
subtitle_data.pop("word_segments", None)
|
728 |
+
for segment in subtitle_data["segments"]:
|
729 |
+
for key in ["speaker", "chars", "words"]:
|
730 |
+
segment.pop(key, None)
|
731 |
+
|
732 |
+
writer(
|
733 |
+
subtitle_data,
|
734 |
+
support_name,
|
735 |
+
word_options,
|
736 |
+
)
|
737 |
+
|
738 |
+
if is_ass_extension:
|
739 |
+
temp_name = filename + ".ass"
|
740 |
+
remove_files(temp_name)
|
741 |
+
convert_sub = f'ffmpeg -i "{sub_file}" "{temp_name}" -y'
|
742 |
+
run_command(convert_sub)
|
743 |
+
sub_file = temp_name
|
744 |
+
|
745 |
+
return sub_file
|
746 |
+
|
747 |
+
|
748 |
+
def process_subtitles(
|
749 |
+
deep_copied_result,
|
750 |
+
align_language,
|
751 |
+
result_diarize,
|
752 |
+
output_format_subtitle,
|
753 |
+
TRANSLATE_AUDIO_TO,
|
754 |
+
):
|
755 |
+
name_ori = "sub_ori."
|
756 |
+
name_tra = "sub_tra."
|
757 |
+
remove_files(
|
758 |
+
[name_ori + output_format_subtitle, name_tra + output_format_subtitle]
|
759 |
+
)
|
760 |
+
|
761 |
+
writer = get_writer(output_format_subtitle, output_dir=".")
|
762 |
+
word_options = {
|
763 |
+
"highlight_words": False,
|
764 |
+
"max_line_count": None,
|
765 |
+
"max_line_width": None,
|
766 |
+
}
|
767 |
+
|
768 |
+
# original lang
|
769 |
+
subs_copy_result = copy.deepcopy(deep_copied_result)
|
770 |
+
subs_copy_result["language"] = (
|
771 |
+
"zh" if align_language == "zh-TW" else align_language
|
772 |
+
)
|
773 |
+
for segment in subs_copy_result["segments"]:
|
774 |
+
segment.pop("speaker", None)
|
775 |
+
|
776 |
+
try:
|
777 |
+
writer(
|
778 |
+
subs_copy_result,
|
779 |
+
name_ori[:-1] + ".mp3",
|
780 |
+
word_options,
|
781 |
+
)
|
782 |
+
except Exception as error:
|
783 |
+
logger.error(str(error))
|
784 |
+
if str(error) == "list indices must be integers or slices, not str":
|
785 |
+
logger.error(
|
786 |
+
"Related to poor word segmentation"
|
787 |
+
" in segments after alignment."
|
788 |
+
)
|
789 |
+
subs_copy_result["segments"][0].pop("words")
|
790 |
+
writer(
|
791 |
+
subs_copy_result,
|
792 |
+
name_ori[:-1] + ".mp3",
|
793 |
+
word_options,
|
794 |
+
)
|
795 |
+
|
796 |
+
# translated lang
|
797 |
+
subs_tra_copy_result = copy.deepcopy(result_diarize)
|
798 |
+
subs_tra_copy_result["language"] = (
|
799 |
+
"ja" if TRANSLATE_AUDIO_TO in ["ja", "zh", "zh-TW"] else align_language
|
800 |
+
)
|
801 |
+
subs_tra_copy_result.pop("word_segments", None)
|
802 |
+
for segment in subs_tra_copy_result["segments"]:
|
803 |
+
for key in ["speaker", "chars", "words"]:
|
804 |
+
segment.pop(key, None)
|
805 |
+
|
806 |
+
writer(
|
807 |
+
subs_tra_copy_result,
|
808 |
+
name_tra[:-1] + ".mp3",
|
809 |
+
word_options,
|
810 |
+
)
|
811 |
+
|
812 |
+
return name_tra + output_format_subtitle
|
813 |
+
|
814 |
+
|
815 |
+
def linguistic_level_segments(
|
816 |
+
result_base,
|
817 |
+
linguistic_unit="word", # word or char
|
818 |
+
):
|
819 |
+
linguistic_unit = linguistic_unit[:4]
|
820 |
+
linguistic_unit_key = linguistic_unit + "s"
|
821 |
+
result = copy.deepcopy(result_base)
|
822 |
+
|
823 |
+
if linguistic_unit_key not in result["segments"][0].keys():
|
824 |
+
raise ValueError("No alignment detected, can't process")
|
825 |
+
|
826 |
+
segments_by_unit = []
|
827 |
+
for segment in result["segments"]:
|
828 |
+
segment_units = segment[linguistic_unit_key]
|
829 |
+
# segment_speaker = segment.get("speaker", "SPEAKER_00")
|
830 |
+
|
831 |
+
for unit in segment_units:
|
832 |
+
|
833 |
+
text = unit[linguistic_unit]
|
834 |
+
|
835 |
+
if "start" in unit.keys():
|
836 |
+
segments_by_unit.append(
|
837 |
+
{
|
838 |
+
"start": unit["start"],
|
839 |
+
"end": unit["end"],
|
840 |
+
"text": text,
|
841 |
+
# "speaker": segment_speaker,
|
842 |
+
}
|
843 |
+
)
|
844 |
+
elif not segments_by_unit:
|
845 |
+
pass
|
846 |
+
else:
|
847 |
+
segments_by_unit[-1]["text"] += text
|
848 |
+
|
849 |
+
return {"segments": segments_by_unit}
|
850 |
+
|
851 |
+
|
852 |
+
def break_aling_segments(
|
853 |
+
result: dict,
|
854 |
+
break_characters: str = "", # ":|,|.|"
|
855 |
+
):
|
856 |
+
result_align = copy.deepcopy(result)
|
857 |
+
|
858 |
+
break_characters_list = break_characters.split("|")
|
859 |
+
break_characters_list = [i for i in break_characters_list if i != '']
|
860 |
+
|
861 |
+
if not break_characters_list:
|
862 |
+
logger.info("No valid break characters were specified.")
|
863 |
+
return result
|
864 |
+
|
865 |
+
logger.info(f"Redivide text segments by: {str(break_characters_list)}")
|
866 |
+
|
867 |
+
# create new with filters
|
868 |
+
normal = []
|
869 |
+
|
870 |
+
def process_chars(chars, letter_new_start, num, text):
|
871 |
+
start_key, end_key = "start", "end"
|
872 |
+
start_value = end_value = None
|
873 |
+
|
874 |
+
for char in chars:
|
875 |
+
if start_key in char:
|
876 |
+
start_value = char[start_key]
|
877 |
+
break
|
878 |
+
|
879 |
+
for char in reversed(chars):
|
880 |
+
if end_key in char:
|
881 |
+
end_value = char[end_key]
|
882 |
+
break
|
883 |
+
|
884 |
+
if not start_value or not end_value:
|
885 |
+
raise Exception(
|
886 |
+
f"Unable to obtain a valid timestamp for chars: {str(chars)}"
|
887 |
+
)
|
888 |
+
|
889 |
+
return {
|
890 |
+
"start": start_value,
|
891 |
+
"end": end_value,
|
892 |
+
"text": text,
|
893 |
+
"words": chars,
|
894 |
+
}
|
895 |
+
|
896 |
+
for i, segment in enumerate(result_align['segments']):
|
897 |
+
|
898 |
+
logger.debug(f"- Process segment: {i}, text: {segment['text']}")
|
899 |
+
# start = segment['start']
|
900 |
+
letter_new_start = 0
|
901 |
+
for num, char in enumerate(segment['chars']):
|
902 |
+
|
903 |
+
if char["char"] is None:
|
904 |
+
continue
|
905 |
+
|
906 |
+
# if "start" in char:
|
907 |
+
# start = char["start"]
|
908 |
+
|
909 |
+
# if "end" in char:
|
910 |
+
# end = char["end"]
|
911 |
+
|
912 |
+
# Break by character
|
913 |
+
if char['char'] in break_characters_list:
|
914 |
+
|
915 |
+
text = segment['text'][letter_new_start:num+1]
|
916 |
+
|
917 |
+
logger.debug(
|
918 |
+
f"Break in: {char['char']}, position: {num}, text: {text}"
|
919 |
+
)
|
920 |
+
|
921 |
+
chars = segment['chars'][letter_new_start:num+1]
|
922 |
+
|
923 |
+
if not text:
|
924 |
+
logger.debug("No text")
|
925 |
+
continue
|
926 |
+
|
927 |
+
if num == 0 and not text.strip():
|
928 |
+
logger.debug("blank space in start")
|
929 |
+
continue
|
930 |
+
|
931 |
+
if len(text) == 1:
|
932 |
+
logger.debug(f"Short char append, num: {num}")
|
933 |
+
normal[-1]["text"] += text
|
934 |
+
normal[-1]["words"].append(chars)
|
935 |
+
continue
|
936 |
+
|
937 |
+
# logger.debug(chars)
|
938 |
+
normal_dict = process_chars(chars, letter_new_start, num, text)
|
939 |
+
|
940 |
+
letter_new_start = num+1
|
941 |
+
|
942 |
+
normal.append(normal_dict)
|
943 |
+
|
944 |
+
# If we reach the end of the segment, add the last part of chars.
|
945 |
+
if num == len(segment["chars"]) - 1:
|
946 |
+
|
947 |
+
text = segment['text'][letter_new_start:num+1]
|
948 |
+
|
949 |
+
# If remain text len is not default len text
|
950 |
+
if num not in [len(text)-1, len(text)] and text:
|
951 |
+
logger.debug(f'Remaining text: {text}')
|
952 |
+
|
953 |
+
if not text:
|
954 |
+
logger.debug("No remaining text.")
|
955 |
+
continue
|
956 |
+
|
957 |
+
if len(text) == 1:
|
958 |
+
logger.debug(f"Short char append, num: {num}")
|
959 |
+
normal[-1]["text"] += text
|
960 |
+
normal[-1]["words"].append(chars)
|
961 |
+
continue
|
962 |
+
|
963 |
+
chars = segment['chars'][letter_new_start:num+1]
|
964 |
+
|
965 |
+
normal_dict = process_chars(chars, letter_new_start, num, text)
|
966 |
+
|
967 |
+
letter_new_start = num+1
|
968 |
+
|
969 |
+
normal.append(normal_dict)
|
970 |
+
|
971 |
+
# Rename char to word
|
972 |
+
for item in normal:
|
973 |
+
words_list = item['words']
|
974 |
+
for word_item in words_list:
|
975 |
+
if 'char' in word_item:
|
976 |
+
word_item['word'] = word_item.pop('char')
|
977 |
+
|
978 |
+
# Convert to dict default
|
979 |
+
break_segments = {"segments": normal}
|
980 |
+
|
981 |
+
msg_count = (
|
982 |
+
f"Segment count before: {len(result['segments'])}, "
|
983 |
+
f"after: {len(break_segments['segments'])}."
|
984 |
+
)
|
985 |
+
logger.info(msg_count)
|
986 |
+
|
987 |
+
return break_segments
|
soni_translate/text_to_speech.py
ADDED
@@ -0,0 +1,1574 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from gtts import gTTS
|
2 |
+
import edge_tts, asyncio, json, glob # noqa
|
3 |
+
from tqdm import tqdm
|
4 |
+
import librosa, os, re, torch, gc, subprocess # noqa
|
5 |
+
from .language_configuration import (
|
6 |
+
fix_code_language,
|
7 |
+
BARK_VOICES_LIST,
|
8 |
+
VITS_VOICES_LIST,
|
9 |
+
)
|
10 |
+
from .utils import (
|
11 |
+
download_manager,
|
12 |
+
create_directories,
|
13 |
+
copy_files,
|
14 |
+
rename_file,
|
15 |
+
remove_directory_contents,
|
16 |
+
remove_files,
|
17 |
+
run_command,
|
18 |
+
)
|
19 |
+
import numpy as np
|
20 |
+
from typing import Any, Dict
|
21 |
+
from pathlib import Path
|
22 |
+
import soundfile as sf
|
23 |
+
import platform
|
24 |
+
import logging
|
25 |
+
import traceback
|
26 |
+
from .logging_setup import logger
|
27 |
+
|
28 |
+
|
29 |
+
class TTS_OperationError(Exception):
|
30 |
+
def __init__(self, message="The operation did not complete successfully."):
|
31 |
+
self.message = message
|
32 |
+
super().__init__(self.message)
|
33 |
+
|
34 |
+
|
35 |
+
def verify_saved_file_and_size(filename):
|
36 |
+
if not os.path.exists(filename):
|
37 |
+
raise TTS_OperationError(f"File '{filename}' was not saved.")
|
38 |
+
if os.path.getsize(filename) == 0:
|
39 |
+
raise TTS_OperationError(
|
40 |
+
f"File '{filename}' has a zero size. "
|
41 |
+
"Related to incorrect TTS for the target language"
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename):
|
46 |
+
traceback.print_exc()
|
47 |
+
logger.error(f"Error: {str(error)}")
|
48 |
+
try:
|
49 |
+
from tempfile import TemporaryFile
|
50 |
+
|
51 |
+
tts = gTTS(segment["text"], lang=fix_code_language(TRANSLATE_AUDIO_TO))
|
52 |
+
# tts.save(filename)
|
53 |
+
f = TemporaryFile()
|
54 |
+
tts.write_to_fp(f)
|
55 |
+
|
56 |
+
# Reset the file pointer to the beginning of the file
|
57 |
+
f.seek(0)
|
58 |
+
|
59 |
+
# Read audio data from the TemporaryFile using soundfile
|
60 |
+
audio_data, samplerate = sf.read(f)
|
61 |
+
f.close() # Close the TemporaryFile
|
62 |
+
sf.write(
|
63 |
+
filename, audio_data, samplerate, format="ogg", subtype="vorbis"
|
64 |
+
)
|
65 |
+
|
66 |
+
logger.warning(
|
67 |
+
'TTS auxiliary will be utilized '
|
68 |
+
f'rather than TTS: {segment["tts_name"]}'
|
69 |
+
)
|
70 |
+
verify_saved_file_and_size(filename)
|
71 |
+
except Exception as error:
|
72 |
+
logger.critical(f"Error: {str(error)}")
|
73 |
+
sample_rate_aux = 22050
|
74 |
+
duration = float(segment["end"]) - float(segment["start"])
|
75 |
+
data = np.zeros(int(sample_rate_aux * duration)).astype(np.float32)
|
76 |
+
sf.write(
|
77 |
+
filename, data, sample_rate_aux, format="ogg", subtype="vorbis"
|
78 |
+
)
|
79 |
+
logger.error("Audio will be replaced -> [silent audio].")
|
80 |
+
verify_saved_file_and_size(filename)
|
81 |
+
|
82 |
+
|
83 |
+
def pad_array(array, sr):
|
84 |
+
|
85 |
+
if isinstance(array, list):
|
86 |
+
array = np.array(array)
|
87 |
+
|
88 |
+
if not array.shape[0]:
|
89 |
+
raise ValueError("The generated audio does not contain any data")
|
90 |
+
|
91 |
+
valid_indices = np.where(np.abs(array) > 0.001)[0]
|
92 |
+
|
93 |
+
if len(valid_indices) == 0:
|
94 |
+
logger.debug(f"No valid indices: {array}")
|
95 |
+
return array
|
96 |
+
|
97 |
+
try:
|
98 |
+
pad_indice = int(0.1 * sr)
|
99 |
+
start_pad = max(0, valid_indices[0] - pad_indice)
|
100 |
+
end_pad = min(len(array), valid_indices[-1] + 1 + pad_indice)
|
101 |
+
padded_array = array[start_pad:end_pad]
|
102 |
+
return padded_array
|
103 |
+
except Exception as error:
|
104 |
+
logger.error(str(error))
|
105 |
+
return array
|
106 |
+
|
107 |
+
|
108 |
+
# =====================================
|
109 |
+
# EDGE TTS
|
110 |
+
# =====================================
|
111 |
+
|
112 |
+
|
113 |
+
def edge_tts_voices_list():
|
114 |
+
try:
|
115 |
+
completed_process = subprocess.run(
|
116 |
+
["edge-tts", "--list-voices"], capture_output=True, text=True
|
117 |
+
)
|
118 |
+
lines = completed_process.stdout.strip().split("\n")
|
119 |
+
except Exception as error:
|
120 |
+
logger.debug(str(error))
|
121 |
+
lines = []
|
122 |
+
|
123 |
+
voices = []
|
124 |
+
for line in lines:
|
125 |
+
if line.startswith("Name: "):
|
126 |
+
voice_entry = {}
|
127 |
+
voice_entry["Name"] = line.split(": ")[1]
|
128 |
+
elif line.startswith("Gender: "):
|
129 |
+
voice_entry["Gender"] = line.split(": ")[1]
|
130 |
+
voices.append(voice_entry)
|
131 |
+
|
132 |
+
formatted_voices = [
|
133 |
+
f"{entry['Name']}-{entry['Gender']}" for entry in voices
|
134 |
+
]
|
135 |
+
|
136 |
+
if not formatted_voices:
|
137 |
+
logger.warning(
|
138 |
+
"The list of Edge TTS voices could not be obtained, "
|
139 |
+
"switching to an alternative method"
|
140 |
+
)
|
141 |
+
tts_voice_list = asyncio.new_event_loop().run_until_complete(
|
142 |
+
edge_tts.list_voices()
|
143 |
+
)
|
144 |
+
formatted_voices = sorted(
|
145 |
+
[f"{v['ShortName']}-{v['Gender']}" for v in tts_voice_list]
|
146 |
+
)
|
147 |
+
|
148 |
+
if not formatted_voices:
|
149 |
+
logger.error("Can't get EDGE TTS - list voices")
|
150 |
+
|
151 |
+
return formatted_voices
|
152 |
+
|
153 |
+
|
154 |
+
def segments_egde_tts(filtered_edge_segments, TRANSLATE_AUDIO_TO, is_gui):
|
155 |
+
for segment in tqdm(filtered_edge_segments["segments"]):
|
156 |
+
speaker = segment["speaker"] # noqa
|
157 |
+
text = segment["text"]
|
158 |
+
start = segment["start"]
|
159 |
+
tts_name = segment["tts_name"]
|
160 |
+
|
161 |
+
# make the tts audio
|
162 |
+
filename = f"audio/{start}.ogg"
|
163 |
+
temp_file = filename[:-3] + "mp3"
|
164 |
+
|
165 |
+
logger.info(f"{text} >> {filename}")
|
166 |
+
try:
|
167 |
+
if is_gui:
|
168 |
+
asyncio.run(
|
169 |
+
edge_tts.Communicate(
|
170 |
+
text, "-".join(tts_name.split("-")[:-1])
|
171 |
+
).save(temp_file)
|
172 |
+
)
|
173 |
+
else:
|
174 |
+
# nest_asyncio.apply() if not is_gui else None
|
175 |
+
command = f'edge-tts -t "{text}" -v "{tts_name.replace("-Male", "").replace("-Female", "")}" --write-media "{temp_file}"'
|
176 |
+
run_command(command)
|
177 |
+
verify_saved_file_and_size(temp_file)
|
178 |
+
|
179 |
+
data, sample_rate = sf.read(temp_file)
|
180 |
+
data = pad_array(data, sample_rate)
|
181 |
+
# os.remove(temp_file)
|
182 |
+
|
183 |
+
# Save file
|
184 |
+
sf.write(
|
185 |
+
file=filename,
|
186 |
+
samplerate=sample_rate,
|
187 |
+
data=data,
|
188 |
+
format="ogg",
|
189 |
+
subtype="vorbis",
|
190 |
+
)
|
191 |
+
verify_saved_file_and_size(filename)
|
192 |
+
|
193 |
+
except Exception as error:
|
194 |
+
error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
|
195 |
+
|
196 |
+
|
197 |
+
# =====================================
|
198 |
+
# BARK TTS
|
199 |
+
# =====================================
|
200 |
+
|
201 |
+
|
202 |
+
def segments_bark_tts(
|
203 |
+
filtered_bark_segments, TRANSLATE_AUDIO_TO, model_id_bark="suno/bark-small"
|
204 |
+
):
|
205 |
+
from transformers import AutoProcessor, BarkModel
|
206 |
+
from optimum.bettertransformer import BetterTransformer
|
207 |
+
|
208 |
+
device = os.environ.get("SONITR_DEVICE")
|
209 |
+
torch_dtype_env = torch.float16 if device == "cuda" else torch.float32
|
210 |
+
|
211 |
+
# load model bark
|
212 |
+
model = BarkModel.from_pretrained(
|
213 |
+
model_id_bark, torch_dtype=torch_dtype_env
|
214 |
+
).to(device)
|
215 |
+
model = model.to(device)
|
216 |
+
processor = AutoProcessor.from_pretrained(
|
217 |
+
model_id_bark, return_tensors="pt"
|
218 |
+
) # , padding=True
|
219 |
+
if device == "cuda":
|
220 |
+
# convert to bettertransformer
|
221 |
+
model = BetterTransformer.transform(model, keep_original_model=False)
|
222 |
+
# enable CPU offload
|
223 |
+
# model.enable_cpu_offload()
|
224 |
+
sampling_rate = model.generation_config.sample_rate
|
225 |
+
|
226 |
+
# filtered_segments = filtered_bark_segments['segments']
|
227 |
+
# Sorting the segments by 'tts_name'
|
228 |
+
# sorted_segments = sorted(filtered_segments, key=lambda x: x['tts_name'])
|
229 |
+
# logger.debug(sorted_segments)
|
230 |
+
|
231 |
+
for segment in tqdm(filtered_bark_segments["segments"]):
|
232 |
+
speaker = segment["speaker"] # noqa
|
233 |
+
text = segment["text"]
|
234 |
+
start = segment["start"]
|
235 |
+
tts_name = segment["tts_name"]
|
236 |
+
|
237 |
+
inputs = processor(text, voice_preset=BARK_VOICES_LIST[tts_name]).to(
|
238 |
+
device
|
239 |
+
)
|
240 |
+
|
241 |
+
# make the tts audio
|
242 |
+
filename = f"audio/{start}.ogg"
|
243 |
+
logger.info(f"{text} >> {filename}")
|
244 |
+
try:
|
245 |
+
# Infer
|
246 |
+
with torch.inference_mode():
|
247 |
+
speech_output = model.generate(
|
248 |
+
**inputs,
|
249 |
+
do_sample=True,
|
250 |
+
fine_temperature=0.4,
|
251 |
+
coarse_temperature=0.8,
|
252 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
253 |
+
)
|
254 |
+
# Save file
|
255 |
+
data_tts = pad_array(
|
256 |
+
speech_output.cpu().numpy().squeeze().astype(np.float32),
|
257 |
+
sampling_rate,
|
258 |
+
)
|
259 |
+
sf.write(
|
260 |
+
file=filename,
|
261 |
+
samplerate=sampling_rate,
|
262 |
+
data=data_tts,
|
263 |
+
format="ogg",
|
264 |
+
subtype="vorbis",
|
265 |
+
)
|
266 |
+
verify_saved_file_and_size(filename)
|
267 |
+
except Exception as error:
|
268 |
+
error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
|
269 |
+
gc.collect()
|
270 |
+
torch.cuda.empty_cache()
|
271 |
+
try:
|
272 |
+
del processor
|
273 |
+
del model
|
274 |
+
gc.collect()
|
275 |
+
torch.cuda.empty_cache()
|
276 |
+
except Exception as error:
|
277 |
+
logger.error(str(error))
|
278 |
+
gc.collect()
|
279 |
+
torch.cuda.empty_cache()
|
280 |
+
|
281 |
+
|
282 |
+
# =====================================
|
283 |
+
# VITS TTS
|
284 |
+
# =====================================
|
285 |
+
|
286 |
+
|
287 |
+
def uromanize(input_string):
|
288 |
+
"""Convert non-Roman strings to Roman using the `uroman` perl package."""
|
289 |
+
# script_path = os.path.join(uroman_path, "bin", "uroman.pl")
|
290 |
+
|
291 |
+
if not os.path.exists("./uroman"):
|
292 |
+
logger.info(
|
293 |
+
"Clonning repository uroman https://github.com/isi-nlp/uroman.git"
|
294 |
+
" for romanize the text"
|
295 |
+
)
|
296 |
+
process = subprocess.Popen(
|
297 |
+
["git", "clone", "https://github.com/isi-nlp/uroman.git"],
|
298 |
+
stdout=subprocess.PIPE,
|
299 |
+
stderr=subprocess.PIPE,
|
300 |
+
)
|
301 |
+
stdout, stderr = process.communicate()
|
302 |
+
script_path = os.path.join("./uroman", "bin", "uroman.pl")
|
303 |
+
|
304 |
+
command = ["perl", script_path]
|
305 |
+
|
306 |
+
process = subprocess.Popen(
|
307 |
+
command,
|
308 |
+
stdin=subprocess.PIPE,
|
309 |
+
stdout=subprocess.PIPE,
|
310 |
+
stderr=subprocess.PIPE,
|
311 |
+
)
|
312 |
+
# Execute the perl command
|
313 |
+
stdout, stderr = process.communicate(input=input_string.encode())
|
314 |
+
|
315 |
+
if process.returncode != 0:
|
316 |
+
raise ValueError(f"Error {process.returncode}: {stderr.decode()}")
|
317 |
+
|
318 |
+
# Return the output as a string and skip the new-line character at the end
|
319 |
+
return stdout.decode()[:-1]
|
320 |
+
|
321 |
+
|
322 |
+
def segments_vits_tts(filtered_vits_segments, TRANSLATE_AUDIO_TO):
|
323 |
+
from transformers import VitsModel, AutoTokenizer
|
324 |
+
|
325 |
+
filtered_segments = filtered_vits_segments["segments"]
|
326 |
+
# Sorting the segments by 'tts_name'
|
327 |
+
sorted_segments = sorted(filtered_segments, key=lambda x: x["tts_name"])
|
328 |
+
logger.debug(sorted_segments)
|
329 |
+
|
330 |
+
model_name_key = None
|
331 |
+
for segment in tqdm(sorted_segments):
|
332 |
+
speaker = segment["speaker"] # noqa
|
333 |
+
text = segment["text"]
|
334 |
+
start = segment["start"]
|
335 |
+
tts_name = segment["tts_name"]
|
336 |
+
|
337 |
+
if tts_name != model_name_key:
|
338 |
+
model_name_key = tts_name
|
339 |
+
model = VitsModel.from_pretrained(VITS_VOICES_LIST[tts_name])
|
340 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
341 |
+
VITS_VOICES_LIST[tts_name]
|
342 |
+
)
|
343 |
+
sampling_rate = model.config.sampling_rate
|
344 |
+
|
345 |
+
if tokenizer.is_uroman:
|
346 |
+
romanize_text = uromanize(text)
|
347 |
+
logger.debug(f"Romanize text: {romanize_text}")
|
348 |
+
inputs = tokenizer(romanize_text, return_tensors="pt")
|
349 |
+
else:
|
350 |
+
inputs = tokenizer(text, return_tensors="pt")
|
351 |
+
|
352 |
+
# make the tts audio
|
353 |
+
filename = f"audio/{start}.ogg"
|
354 |
+
logger.info(f"{text} >> {filename}")
|
355 |
+
try:
|
356 |
+
# Infer
|
357 |
+
with torch.no_grad():
|
358 |
+
speech_output = model(**inputs).waveform
|
359 |
+
|
360 |
+
data_tts = pad_array(
|
361 |
+
speech_output.cpu().numpy().squeeze().astype(np.float32),
|
362 |
+
sampling_rate,
|
363 |
+
)
|
364 |
+
# Save file
|
365 |
+
sf.write(
|
366 |
+
file=filename,
|
367 |
+
samplerate=sampling_rate,
|
368 |
+
data=data_tts,
|
369 |
+
format="ogg",
|
370 |
+
subtype="vorbis",
|
371 |
+
)
|
372 |
+
verify_saved_file_and_size(filename)
|
373 |
+
except Exception as error:
|
374 |
+
error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
|
375 |
+
gc.collect()
|
376 |
+
torch.cuda.empty_cache()
|
377 |
+
try:
|
378 |
+
del tokenizer
|
379 |
+
del model
|
380 |
+
gc.collect()
|
381 |
+
torch.cuda.empty_cache()
|
382 |
+
except Exception as error:
|
383 |
+
logger.error(str(error))
|
384 |
+
gc.collect()
|
385 |
+
torch.cuda.empty_cache()
|
386 |
+
|
387 |
+
|
388 |
+
# =====================================
|
389 |
+
# Coqui XTTS
|
390 |
+
# =====================================
|
391 |
+
|
392 |
+
|
393 |
+
def coqui_xtts_voices_list():
|
394 |
+
main_folder = "_XTTS_"
|
395 |
+
pattern_coqui = re.compile(r".+\.(wav|mp3|ogg|m4a)$")
|
396 |
+
pattern_automatic_speaker = re.compile(r"AUTOMATIC_SPEAKER_\d+\.wav$")
|
397 |
+
|
398 |
+
# List only files in the directory matching the pattern but not matching
|
399 |
+
# AUTOMATIC_SPEAKER_00.wav, AUTOMATIC_SPEAKER_01.wav, etc.
|
400 |
+
wav_voices = [
|
401 |
+
"_XTTS_/" + f
|
402 |
+
for f in os.listdir(main_folder)
|
403 |
+
if os.path.isfile(os.path.join(main_folder, f))
|
404 |
+
and pattern_coqui.match(f)
|
405 |
+
and not pattern_automatic_speaker.match(f)
|
406 |
+
]
|
407 |
+
|
408 |
+
return ["_XTTS_/AUTOMATIC.wav"] + wav_voices
|
409 |
+
|
410 |
+
|
411 |
+
def seconds_to_hhmmss_ms(seconds):
|
412 |
+
hours = seconds // 3600
|
413 |
+
minutes = (seconds % 3600) // 60
|
414 |
+
seconds = seconds % 60
|
415 |
+
milliseconds = int((seconds - int(seconds)) * 1000)
|
416 |
+
return "%02d:%02d:%02d.%03d" % (hours, minutes, int(seconds), milliseconds)
|
417 |
+
|
418 |
+
|
419 |
+
def audio_trimming(audio_path, destination, start, end):
|
420 |
+
if isinstance(start, (int, float)):
|
421 |
+
start = seconds_to_hhmmss_ms(start)
|
422 |
+
if isinstance(end, (int, float)):
|
423 |
+
end = seconds_to_hhmmss_ms(end)
|
424 |
+
|
425 |
+
if destination:
|
426 |
+
file_directory = destination
|
427 |
+
else:
|
428 |
+
file_directory = os.path.dirname(audio_path)
|
429 |
+
|
430 |
+
file_name = os.path.splitext(os.path.basename(audio_path))[0]
|
431 |
+
file_ = f"{file_name}_trim.wav"
|
432 |
+
# file_ = f'{os.path.splitext(audio_path)[0]}_trim.wav'
|
433 |
+
output_path = os.path.join(file_directory, file_)
|
434 |
+
|
435 |
+
# -t (duration from -ss) | -to (time stop) | -af silenceremove=1:0:-50dB (remove silence)
|
436 |
+
command = f'ffmpeg -y -loglevel error -i "{audio_path}" -ss {start} -to {end} -acodec pcm_s16le -f wav "{output_path}"'
|
437 |
+
run_command(command)
|
438 |
+
|
439 |
+
return output_path
|
440 |
+
|
441 |
+
|
442 |
+
def convert_to_xtts_good_sample(audio_path: str = "", destination: str = ""):
|
443 |
+
if destination:
|
444 |
+
file_directory = destination
|
445 |
+
else:
|
446 |
+
file_directory = os.path.dirname(audio_path)
|
447 |
+
|
448 |
+
file_name = os.path.splitext(os.path.basename(audio_path))[0]
|
449 |
+
file_ = f"{file_name}_good_sample.wav"
|
450 |
+
# file_ = f'{os.path.splitext(audio_path)[0]}_good_sample.wav'
|
451 |
+
mono_path = os.path.join(file_directory, file_) # get root
|
452 |
+
|
453 |
+
command = f'ffmpeg -y -loglevel error -i "{audio_path}" -ac 1 -ar 22050 -sample_fmt s16 -f wav "{mono_path}"'
|
454 |
+
run_command(command)
|
455 |
+
|
456 |
+
return mono_path
|
457 |
+
|
458 |
+
|
459 |
+
def sanitize_file_name(file_name):
|
460 |
+
import unicodedata
|
461 |
+
|
462 |
+
# Normalize the string to NFKD form to separate combined characters into
|
463 |
+
# base characters and diacritics
|
464 |
+
normalized_name = unicodedata.normalize("NFKD", file_name)
|
465 |
+
# Replace any non-ASCII characters or special symbols with an underscore
|
466 |
+
sanitized_name = re.sub(r"[^\w\s.-]", "_", normalized_name)
|
467 |
+
return sanitized_name
|
468 |
+
|
469 |
+
|
470 |
+
def create_wav_file_vc(
|
471 |
+
sample_name="", # name final file
|
472 |
+
audio_wav="", # path
|
473 |
+
start=None, # trim start
|
474 |
+
end=None, # trim end
|
475 |
+
output_final_path="_XTTS_",
|
476 |
+
get_vocals_dereverb=True,
|
477 |
+
):
|
478 |
+
sample_name = sample_name if sample_name else "default_name"
|
479 |
+
sample_name = sanitize_file_name(sample_name)
|
480 |
+
audio_wav = audio_wav if isinstance(audio_wav, str) else audio_wav.name
|
481 |
+
|
482 |
+
BASE_DIR = (
|
483 |
+
"." # os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
484 |
+
)
|
485 |
+
|
486 |
+
output_dir = os.path.join(BASE_DIR, "clean_song_output") # remove content
|
487 |
+
# remove_directory_contents(output_dir)
|
488 |
+
|
489 |
+
if start or end:
|
490 |
+
# Cut file
|
491 |
+
audio_segment = audio_trimming(audio_wav, output_dir, start, end)
|
492 |
+
else:
|
493 |
+
# Complete file
|
494 |
+
audio_segment = audio_wav
|
495 |
+
|
496 |
+
from .mdx_net import process_uvr_task
|
497 |
+
|
498 |
+
try:
|
499 |
+
_, _, _, _, audio_segment = process_uvr_task(
|
500 |
+
orig_song_path=audio_segment,
|
501 |
+
main_vocals=True,
|
502 |
+
dereverb=get_vocals_dereverb,
|
503 |
+
)
|
504 |
+
except Exception as error:
|
505 |
+
logger.error(str(error))
|
506 |
+
|
507 |
+
sample = convert_to_xtts_good_sample(audio_segment)
|
508 |
+
|
509 |
+
sample_name = f"{sample_name}.wav"
|
510 |
+
sample_rename = rename_file(sample, sample_name)
|
511 |
+
|
512 |
+
copy_files(sample_rename, output_final_path)
|
513 |
+
|
514 |
+
final_sample = os.path.join(output_final_path, sample_name)
|
515 |
+
if os.path.exists(final_sample):
|
516 |
+
logger.info(final_sample)
|
517 |
+
return final_sample
|
518 |
+
else:
|
519 |
+
raise Exception(f"Error wav: {final_sample}")
|
520 |
+
|
521 |
+
|
522 |
+
def create_new_files_for_vc(
|
523 |
+
speakers_coqui,
|
524 |
+
segments_base,
|
525 |
+
dereverb_automatic=True
|
526 |
+
):
|
527 |
+
# before function delete automatic delete_previous_automatic
|
528 |
+
output_dir = os.path.join(".", "clean_song_output") # remove content
|
529 |
+
remove_directory_contents(output_dir)
|
530 |
+
|
531 |
+
for speaker in speakers_coqui:
|
532 |
+
filtered_speaker = [
|
533 |
+
segment
|
534 |
+
for segment in segments_base
|
535 |
+
if segment["speaker"] == speaker
|
536 |
+
]
|
537 |
+
if len(filtered_speaker) > 4:
|
538 |
+
filtered_speaker = filtered_speaker[1:]
|
539 |
+
if filtered_speaker[0]["tts_name"] == "_XTTS_/AUTOMATIC.wav":
|
540 |
+
name_automatic_wav = f"AUTOMATIC_{speaker}"
|
541 |
+
if os.path.exists(f"_XTTS_/{name_automatic_wav}.wav"):
|
542 |
+
logger.info(f"WAV automatic {speaker} exists")
|
543 |
+
# path_wav = path_automatic_wav
|
544 |
+
pass
|
545 |
+
else:
|
546 |
+
# create wav
|
547 |
+
wav_ok = False
|
548 |
+
for seg in filtered_speaker:
|
549 |
+
duration = float(seg["end"]) - float(seg["start"])
|
550 |
+
if duration > 7.0 and duration < 12.0:
|
551 |
+
logger.info(
|
552 |
+
f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {duration}, {seg["text"]}'
|
553 |
+
)
|
554 |
+
create_wav_file_vc(
|
555 |
+
sample_name=name_automatic_wav,
|
556 |
+
audio_wav="audio.wav",
|
557 |
+
start=(float(seg["start"]) + 1.0),
|
558 |
+
end=(float(seg["end"]) - 1.0),
|
559 |
+
get_vocals_dereverb=dereverb_automatic,
|
560 |
+
)
|
561 |
+
wav_ok = True
|
562 |
+
break
|
563 |
+
|
564 |
+
if not wav_ok:
|
565 |
+
logger.info("Taking the first segment")
|
566 |
+
seg = filtered_speaker[0]
|
567 |
+
logger.info(
|
568 |
+
f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {seg["text"]}'
|
569 |
+
)
|
570 |
+
max_duration = float(seg["end"]) - float(seg["start"])
|
571 |
+
max_duration = max(2.0, min(max_duration, 9.0))
|
572 |
+
|
573 |
+
create_wav_file_vc(
|
574 |
+
sample_name=name_automatic_wav,
|
575 |
+
audio_wav="audio.wav",
|
576 |
+
start=(float(seg["start"])),
|
577 |
+
end=(float(seg["start"]) + max_duration),
|
578 |
+
get_vocals_dereverb=dereverb_automatic,
|
579 |
+
)
|
580 |
+
|
581 |
+
|
582 |
+
def segments_coqui_tts(
|
583 |
+
filtered_coqui_segments,
|
584 |
+
TRANSLATE_AUDIO_TO,
|
585 |
+
model_id_coqui="tts_models/multilingual/multi-dataset/xtts_v2",
|
586 |
+
speakers_coqui=None,
|
587 |
+
delete_previous_automatic=True,
|
588 |
+
dereverb_automatic=True,
|
589 |
+
emotion=None,
|
590 |
+
):
|
591 |
+
"""XTTS
|
592 |
+
Install:
|
593 |
+
pip install -q TTS==0.21.1
|
594 |
+
pip install -q numpy==1.23.5
|
595 |
+
|
596 |
+
Notes:
|
597 |
+
- tts_name is the wav|mp3|ogg|m4a file for VC
|
598 |
+
"""
|
599 |
+
from TTS.api import TTS
|
600 |
+
|
601 |
+
TRANSLATE_AUDIO_TO = fix_code_language(TRANSLATE_AUDIO_TO, syntax="coqui")
|
602 |
+
supported_lang_coqui = [
|
603 |
+
"zh-cn",
|
604 |
+
"en",
|
605 |
+
"fr",
|
606 |
+
"de",
|
607 |
+
"it",
|
608 |
+
"pt",
|
609 |
+
"pl",
|
610 |
+
"tr",
|
611 |
+
"ru",
|
612 |
+
"nl",
|
613 |
+
"cs",
|
614 |
+
"ar",
|
615 |
+
"es",
|
616 |
+
"hu",
|
617 |
+
"ko",
|
618 |
+
"ja",
|
619 |
+
]
|
620 |
+
if TRANSLATE_AUDIO_TO not in supported_lang_coqui:
|
621 |
+
raise TTS_OperationError(
|
622 |
+
f"'{TRANSLATE_AUDIO_TO}' is not a supported language for Coqui XTTS"
|
623 |
+
)
|
624 |
+
# Emotion and speed can only be used with Coqui Studio models. discontinued
|
625 |
+
# emotions = ["Neutral", "Happy", "Sad", "Angry", "Dull"]
|
626 |
+
|
627 |
+
if delete_previous_automatic:
|
628 |
+
for spk in speakers_coqui:
|
629 |
+
remove_files(f"_XTTS_/AUTOMATIC_{spk}.wav")
|
630 |
+
|
631 |
+
directory_audios_vc = "_XTTS_"
|
632 |
+
create_directories(directory_audios_vc)
|
633 |
+
create_new_files_for_vc(
|
634 |
+
speakers_coqui,
|
635 |
+
filtered_coqui_segments["segments"],
|
636 |
+
dereverb_automatic,
|
637 |
+
)
|
638 |
+
|
639 |
+
# Init TTS
|
640 |
+
device = os.environ.get("SONITR_DEVICE")
|
641 |
+
model = TTS(model_id_coqui).to(device)
|
642 |
+
sampling_rate = 24000
|
643 |
+
|
644 |
+
# filtered_segments = filtered_coqui_segments['segments']
|
645 |
+
# Sorting the segments by 'tts_name'
|
646 |
+
# sorted_segments = sorted(filtered_segments, key=lambda x: x['tts_name'])
|
647 |
+
# logger.debug(sorted_segments)
|
648 |
+
|
649 |
+
for segment in tqdm(filtered_coqui_segments["segments"]):
|
650 |
+
speaker = segment["speaker"]
|
651 |
+
text = segment["text"]
|
652 |
+
start = segment["start"]
|
653 |
+
tts_name = segment["tts_name"]
|
654 |
+
if tts_name == "_XTTS_/AUTOMATIC.wav":
|
655 |
+
tts_name = f"_XTTS_/AUTOMATIC_{speaker}.wav"
|
656 |
+
|
657 |
+
# make the tts audio
|
658 |
+
filename = f"audio/{start}.ogg"
|
659 |
+
logger.info(f"{text} >> {filename}")
|
660 |
+
try:
|
661 |
+
# Infer
|
662 |
+
wav = model.tts(
|
663 |
+
text=text, speaker_wav=tts_name, language=TRANSLATE_AUDIO_TO
|
664 |
+
)
|
665 |
+
data_tts = pad_array(
|
666 |
+
wav,
|
667 |
+
sampling_rate,
|
668 |
+
)
|
669 |
+
# Save file
|
670 |
+
sf.write(
|
671 |
+
file=filename,
|
672 |
+
samplerate=sampling_rate,
|
673 |
+
data=data_tts,
|
674 |
+
format="ogg",
|
675 |
+
subtype="vorbis",
|
676 |
+
)
|
677 |
+
verify_saved_file_and_size(filename)
|
678 |
+
except Exception as error:
|
679 |
+
error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
|
680 |
+
gc.collect()
|
681 |
+
torch.cuda.empty_cache()
|
682 |
+
try:
|
683 |
+
del model
|
684 |
+
gc.collect()
|
685 |
+
torch.cuda.empty_cache()
|
686 |
+
except Exception as error:
|
687 |
+
logger.error(str(error))
|
688 |
+
gc.collect()
|
689 |
+
torch.cuda.empty_cache()
|
690 |
+
|
691 |
+
|
692 |
+
# =====================================
|
693 |
+
# PIPER TTS
|
694 |
+
# =====================================
|
695 |
+
|
696 |
+
|
697 |
+
def piper_tts_voices_list():
|
698 |
+
file_path = download_manager(
|
699 |
+
url="https://huggingface.co/rhasspy/piper-voices/resolve/main/voices.json",
|
700 |
+
path="./PIPER_MODELS",
|
701 |
+
)
|
702 |
+
|
703 |
+
with open(file_path, "r", encoding="utf8") as file:
|
704 |
+
data = json.load(file)
|
705 |
+
piper_id_models = [key + " VITS-onnx" for key in data.keys()]
|
706 |
+
|
707 |
+
return piper_id_models
|
708 |
+
|
709 |
+
|
710 |
+
def replace_text_in_json(file_path, key_to_replace, new_text, condition=None):
|
711 |
+
# Read the JSON file
|
712 |
+
with open(file_path, "r", encoding="utf-8") as file:
|
713 |
+
data = json.load(file)
|
714 |
+
|
715 |
+
# Modify the specified key's value with the new text
|
716 |
+
if key_to_replace in data:
|
717 |
+
if condition:
|
718 |
+
value_condition = condition
|
719 |
+
else:
|
720 |
+
value_condition = data[key_to_replace]
|
721 |
+
|
722 |
+
if data[key_to_replace] == value_condition:
|
723 |
+
data[key_to_replace] = new_text
|
724 |
+
|
725 |
+
# Write the modified content back to the JSON file
|
726 |
+
with open(file_path, "w") as file:
|
727 |
+
json.dump(
|
728 |
+
data, file, indent=2
|
729 |
+
) # Write the modified data back to the file with indentation for readability
|
730 |
+
|
731 |
+
|
732 |
+
def load_piper_model(
|
733 |
+
model: str,
|
734 |
+
data_dir: list,
|
735 |
+
download_dir: str = "",
|
736 |
+
update_voices: bool = False,
|
737 |
+
):
|
738 |
+
from piper import PiperVoice
|
739 |
+
from piper.download import ensure_voice_exists, find_voice, get_voices
|
740 |
+
|
741 |
+
try:
|
742 |
+
import onnxruntime as rt
|
743 |
+
|
744 |
+
if rt.get_device() == "GPU" and os.environ.get("SONITR_DEVICE") == "cuda":
|
745 |
+
logger.debug("onnxruntime device > GPU")
|
746 |
+
cuda = True
|
747 |
+
else:
|
748 |
+
logger.info(
|
749 |
+
"onnxruntime device > CPU"
|
750 |
+
) # try pip install onnxruntime-gpu
|
751 |
+
cuda = False
|
752 |
+
except Exception as error:
|
753 |
+
raise TTS_OperationError(f"onnxruntime error: {str(error)}")
|
754 |
+
|
755 |
+
# Disable CUDA in Windows
|
756 |
+
if platform.system() == "Windows":
|
757 |
+
logger.info("Employing CPU exclusivity with Piper TTS")
|
758 |
+
cuda = False
|
759 |
+
|
760 |
+
if not download_dir:
|
761 |
+
# Download to first data directory by default
|
762 |
+
download_dir = data_dir[0]
|
763 |
+
else:
|
764 |
+
data_dir = [os.path.join(data_dir[0], download_dir)]
|
765 |
+
|
766 |
+
# Download voice if file doesn't exist
|
767 |
+
model_path = Path(model)
|
768 |
+
if not model_path.exists():
|
769 |
+
# Load voice info
|
770 |
+
voices_info = get_voices(download_dir, update_voices=update_voices)
|
771 |
+
|
772 |
+
# Resolve aliases for backwards compatibility with old voice names
|
773 |
+
aliases_info: Dict[str, Any] = {}
|
774 |
+
for voice_info in voices_info.values():
|
775 |
+
for voice_alias in voice_info.get("aliases", []):
|
776 |
+
aliases_info[voice_alias] = {"_is_alias": True, **voice_info}
|
777 |
+
|
778 |
+
voices_info.update(aliases_info)
|
779 |
+
ensure_voice_exists(model, data_dir, download_dir, voices_info)
|
780 |
+
model, config = find_voice(model, data_dir)
|
781 |
+
|
782 |
+
replace_text_in_json(
|
783 |
+
config, "phoneme_type", "espeak", "PhonemeType.ESPEAK"
|
784 |
+
)
|
785 |
+
|
786 |
+
# Load voice
|
787 |
+
voice = PiperVoice.load(model, config_path=config, use_cuda=cuda)
|
788 |
+
|
789 |
+
return voice
|
790 |
+
|
791 |
+
|
792 |
+
def synthesize_text_to_audio_np_array(voice, text, synthesize_args):
|
793 |
+
audio_stream = voice.synthesize_stream_raw(text, **synthesize_args)
|
794 |
+
|
795 |
+
# Collect the audio bytes into a single NumPy array
|
796 |
+
audio_data = b""
|
797 |
+
for audio_bytes in audio_stream:
|
798 |
+
audio_data += audio_bytes
|
799 |
+
|
800 |
+
# Ensure correct data type and convert audio bytes to NumPy array
|
801 |
+
audio_np = np.frombuffer(audio_data, dtype=np.int16)
|
802 |
+
return audio_np
|
803 |
+
|
804 |
+
|
805 |
+
def segments_vits_onnx_tts(filtered_onnx_vits_segments, TRANSLATE_AUDIO_TO):
|
806 |
+
"""
|
807 |
+
Install:
|
808 |
+
pip install -q piper-tts==1.2.0 onnxruntime-gpu # for cuda118
|
809 |
+
"""
|
810 |
+
|
811 |
+
data_dir = [
|
812 |
+
str(Path.cwd())
|
813 |
+
] # "Data directory to check for downloaded models (default: current directory)"
|
814 |
+
download_dir = "PIPER_MODELS"
|
815 |
+
# model_name = "en_US-lessac-medium" tts_name in a dict like VITS
|
816 |
+
update_voices = True # "Download latest voices.json during startup",
|
817 |
+
|
818 |
+
synthesize_args = {
|
819 |
+
"speaker_id": None,
|
820 |
+
"length_scale": 1.0,
|
821 |
+
"noise_scale": 0.667,
|
822 |
+
"noise_w": 0.8,
|
823 |
+
"sentence_silence": 0.0,
|
824 |
+
}
|
825 |
+
|
826 |
+
filtered_segments = filtered_onnx_vits_segments["segments"]
|
827 |
+
# Sorting the segments by 'tts_name'
|
828 |
+
sorted_segments = sorted(filtered_segments, key=lambda x: x["tts_name"])
|
829 |
+
logger.debug(sorted_segments)
|
830 |
+
|
831 |
+
model_name_key = None
|
832 |
+
for segment in tqdm(sorted_segments):
|
833 |
+
speaker = segment["speaker"] # noqa
|
834 |
+
text = segment["text"]
|
835 |
+
start = segment["start"]
|
836 |
+
tts_name = segment["tts_name"].replace(" VITS-onnx", "")
|
837 |
+
|
838 |
+
if tts_name != model_name_key:
|
839 |
+
model_name_key = tts_name
|
840 |
+
model = load_piper_model(
|
841 |
+
tts_name, data_dir, download_dir, update_voices
|
842 |
+
)
|
843 |
+
sampling_rate = model.config.sample_rate
|
844 |
+
|
845 |
+
# make the tts audio
|
846 |
+
filename = f"audio/{start}.ogg"
|
847 |
+
logger.info(f"{text} >> {filename}")
|
848 |
+
try:
|
849 |
+
# Infer
|
850 |
+
speech_output = synthesize_text_to_audio_np_array(
|
851 |
+
model, text, synthesize_args
|
852 |
+
)
|
853 |
+
data_tts = pad_array(
|
854 |
+
speech_output, # .cpu().numpy().squeeze().astype(np.float32),
|
855 |
+
sampling_rate,
|
856 |
+
)
|
857 |
+
# Save file
|
858 |
+
sf.write(
|
859 |
+
file=filename,
|
860 |
+
samplerate=sampling_rate,
|
861 |
+
data=data_tts,
|
862 |
+
format="ogg",
|
863 |
+
subtype="vorbis",
|
864 |
+
)
|
865 |
+
verify_saved_file_and_size(filename)
|
866 |
+
except Exception as error:
|
867 |
+
error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
|
868 |
+
gc.collect()
|
869 |
+
torch.cuda.empty_cache()
|
870 |
+
try:
|
871 |
+
del model
|
872 |
+
gc.collect()
|
873 |
+
torch.cuda.empty_cache()
|
874 |
+
except Exception as error:
|
875 |
+
logger.error(str(error))
|
876 |
+
gc.collect()
|
877 |
+
torch.cuda.empty_cache()
|
878 |
+
|
879 |
+
|
880 |
+
# =====================================
|
881 |
+
# CLOSEAI TTS
|
882 |
+
# =====================================
|
883 |
+
|
884 |
+
|
885 |
+
def segments_openai_tts(
|
886 |
+
filtered_openai_tts_segments, TRANSLATE_AUDIO_TO
|
887 |
+
):
|
888 |
+
from openai import OpenAI
|
889 |
+
|
890 |
+
client = OpenAI()
|
891 |
+
sampling_rate = 24000
|
892 |
+
|
893 |
+
# filtered_segments = filtered_openai_tts_segments['segments']
|
894 |
+
# Sorting the segments by 'tts_name'
|
895 |
+
# sorted_segments = sorted(filtered_segments, key=lambda x: x['tts_name'])
|
896 |
+
|
897 |
+
for segment in tqdm(filtered_openai_tts_segments["segments"]):
|
898 |
+
speaker = segment["speaker"] # noqa
|
899 |
+
text = segment["text"].strip()
|
900 |
+
start = segment["start"]
|
901 |
+
tts_name = segment["tts_name"]
|
902 |
+
|
903 |
+
# make the tts audio
|
904 |
+
filename = f"audio/{start}.ogg"
|
905 |
+
logger.info(f"{text} >> {filename}")
|
906 |
+
|
907 |
+
try:
|
908 |
+
# Request
|
909 |
+
response = client.audio.speech.create(
|
910 |
+
model="tts-1-hd" if "HD" in tts_name else "tts-1",
|
911 |
+
voice=tts_name.split()[0][1:],
|
912 |
+
response_format="wav",
|
913 |
+
input=text
|
914 |
+
)
|
915 |
+
|
916 |
+
audio_bytes = b''
|
917 |
+
for data in response.iter_bytes(chunk_size=4096):
|
918 |
+
audio_bytes += data
|
919 |
+
|
920 |
+
speech_output = np.frombuffer(audio_bytes, dtype=np.int16)
|
921 |
+
|
922 |
+
# Save file
|
923 |
+
data_tts = pad_array(
|
924 |
+
speech_output[240:],
|
925 |
+
sampling_rate,
|
926 |
+
)
|
927 |
+
|
928 |
+
sf.write(
|
929 |
+
file=filename,
|
930 |
+
samplerate=sampling_rate,
|
931 |
+
data=data_tts,
|
932 |
+
format="ogg",
|
933 |
+
subtype="vorbis",
|
934 |
+
)
|
935 |
+
verify_saved_file_and_size(filename)
|
936 |
+
|
937 |
+
except Exception as error:
|
938 |
+
error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
|
939 |
+
|
940 |
+
|
941 |
+
# =====================================
|
942 |
+
# Select task TTS
|
943 |
+
# =====================================
|
944 |
+
|
945 |
+
|
946 |
+
def find_spkr(pattern, speaker_to_voice, segments):
|
947 |
+
return [
|
948 |
+
speaker
|
949 |
+
for speaker, voice in speaker_to_voice.items()
|
950 |
+
if pattern.match(voice) and any(
|
951 |
+
segment["speaker"] == speaker for segment in segments
|
952 |
+
)
|
953 |
+
]
|
954 |
+
|
955 |
+
|
956 |
+
def filter_by_speaker(speakers, segments):
|
957 |
+
return {
|
958 |
+
"segments": [
|
959 |
+
segment
|
960 |
+
for segment in segments
|
961 |
+
if segment["speaker"] in speakers
|
962 |
+
]
|
963 |
+
}
|
964 |
+
|
965 |
+
|
966 |
+
def audio_segmentation_to_voice(
|
967 |
+
result_diarize,
|
968 |
+
TRANSLATE_AUDIO_TO,
|
969 |
+
is_gui,
|
970 |
+
tts_voice00,
|
971 |
+
tts_voice01="",
|
972 |
+
tts_voice02="",
|
973 |
+
tts_voice03="",
|
974 |
+
tts_voice04="",
|
975 |
+
tts_voice05="",
|
976 |
+
tts_voice06="",
|
977 |
+
tts_voice07="",
|
978 |
+
tts_voice08="",
|
979 |
+
tts_voice09="",
|
980 |
+
tts_voice10="",
|
981 |
+
tts_voice11="",
|
982 |
+
dereverb_automatic=True,
|
983 |
+
model_id_bark="suno/bark-small",
|
984 |
+
model_id_coqui="tts_models/multilingual/multi-dataset/xtts_v2",
|
985 |
+
delete_previous_automatic=True,
|
986 |
+
):
|
987 |
+
|
988 |
+
remove_directory_contents("audio")
|
989 |
+
|
990 |
+
# Mapping speakers to voice variables
|
991 |
+
speaker_to_voice = {
|
992 |
+
"SPEAKER_00": tts_voice00,
|
993 |
+
"SPEAKER_01": tts_voice01,
|
994 |
+
"SPEAKER_02": tts_voice02,
|
995 |
+
"SPEAKER_03": tts_voice03,
|
996 |
+
"SPEAKER_04": tts_voice04,
|
997 |
+
"SPEAKER_05": tts_voice05,
|
998 |
+
"SPEAKER_06": tts_voice06,
|
999 |
+
"SPEAKER_07": tts_voice07,
|
1000 |
+
"SPEAKER_08": tts_voice08,
|
1001 |
+
"SPEAKER_09": tts_voice09,
|
1002 |
+
"SPEAKER_10": tts_voice10,
|
1003 |
+
"SPEAKER_11": tts_voice11,
|
1004 |
+
}
|
1005 |
+
|
1006 |
+
# Assign 'SPEAKER_00' to segments without a 'speaker' key
|
1007 |
+
for segment in result_diarize["segments"]:
|
1008 |
+
if "speaker" not in segment:
|
1009 |
+
segment["speaker"] = "SPEAKER_00"
|
1010 |
+
logger.warning(
|
1011 |
+
"NO SPEAKER DETECT IN SEGMENT: First TTS will be used in the"
|
1012 |
+
f" segment time {segment['start'], segment['text']}"
|
1013 |
+
)
|
1014 |
+
# Assign the TTS name
|
1015 |
+
segment["tts_name"] = speaker_to_voice[segment["speaker"]]
|
1016 |
+
|
1017 |
+
# Find TTS method
|
1018 |
+
pattern_edge = re.compile(r".*-(Male|Female)$")
|
1019 |
+
pattern_bark = re.compile(r".* BARK$")
|
1020 |
+
pattern_vits = re.compile(r".* VITS$")
|
1021 |
+
pattern_coqui = re.compile(r".+\.(wav|mp3|ogg|m4a)$")
|
1022 |
+
pattern_vits_onnx = re.compile(r".* VITS-onnx$")
|
1023 |
+
pattern_openai_tts = re.compile(r".* OpenAI-TTS$")
|
1024 |
+
|
1025 |
+
all_segments = result_diarize["segments"]
|
1026 |
+
|
1027 |
+
speakers_edge = find_spkr(pattern_edge, speaker_to_voice, all_segments)
|
1028 |
+
speakers_bark = find_spkr(pattern_bark, speaker_to_voice, all_segments)
|
1029 |
+
speakers_vits = find_spkr(pattern_vits, speaker_to_voice, all_segments)
|
1030 |
+
speakers_coqui = find_spkr(pattern_coqui, speaker_to_voice, all_segments)
|
1031 |
+
speakers_vits_onnx = find_spkr(
|
1032 |
+
pattern_vits_onnx, speaker_to_voice, all_segments
|
1033 |
+
)
|
1034 |
+
speakers_openai_tts = find_spkr(
|
1035 |
+
pattern_openai_tts, speaker_to_voice, all_segments
|
1036 |
+
)
|
1037 |
+
|
1038 |
+
# Filter method in segments
|
1039 |
+
filtered_edge = filter_by_speaker(speakers_edge, all_segments)
|
1040 |
+
filtered_bark = filter_by_speaker(speakers_bark, all_segments)
|
1041 |
+
filtered_vits = filter_by_speaker(speakers_vits, all_segments)
|
1042 |
+
filtered_coqui = filter_by_speaker(speakers_coqui, all_segments)
|
1043 |
+
filtered_vits_onnx = filter_by_speaker(speakers_vits_onnx, all_segments)
|
1044 |
+
filtered_openai_tts = filter_by_speaker(speakers_openai_tts, all_segments)
|
1045 |
+
|
1046 |
+
# Infer
|
1047 |
+
if filtered_edge["segments"]:
|
1048 |
+
logger.info(f"EDGE TTS: {speakers_edge}")
|
1049 |
+
segments_egde_tts(filtered_edge, TRANSLATE_AUDIO_TO, is_gui) # mp3
|
1050 |
+
if filtered_bark["segments"]:
|
1051 |
+
logger.info(f"BARK TTS: {speakers_bark}")
|
1052 |
+
segments_bark_tts(
|
1053 |
+
filtered_bark, TRANSLATE_AUDIO_TO, model_id_bark
|
1054 |
+
) # wav
|
1055 |
+
if filtered_vits["segments"]:
|
1056 |
+
logger.info(f"VITS TTS: {speakers_vits}")
|
1057 |
+
segments_vits_tts(filtered_vits, TRANSLATE_AUDIO_TO) # wav
|
1058 |
+
if filtered_coqui["segments"]:
|
1059 |
+
logger.info(f"Coqui TTS: {speakers_coqui}")
|
1060 |
+
segments_coqui_tts(
|
1061 |
+
filtered_coqui,
|
1062 |
+
TRANSLATE_AUDIO_TO,
|
1063 |
+
model_id_coqui,
|
1064 |
+
speakers_coqui,
|
1065 |
+
delete_previous_automatic,
|
1066 |
+
dereverb_automatic,
|
1067 |
+
) # wav
|
1068 |
+
if filtered_vits_onnx["segments"]:
|
1069 |
+
logger.info(f"PIPER TTS: {speakers_vits_onnx}")
|
1070 |
+
segments_vits_onnx_tts(filtered_vits_onnx, TRANSLATE_AUDIO_TO) # wav
|
1071 |
+
if filtered_openai_tts["segments"]:
|
1072 |
+
logger.info(f"OpenAI TTS: {speakers_openai_tts}")
|
1073 |
+
segments_openai_tts(filtered_openai_tts, TRANSLATE_AUDIO_TO) # wav
|
1074 |
+
|
1075 |
+
[result.pop("tts_name", None) for result in result_diarize["segments"]]
|
1076 |
+
return [
|
1077 |
+
speakers_edge,
|
1078 |
+
speakers_bark,
|
1079 |
+
speakers_vits,
|
1080 |
+
speakers_coqui,
|
1081 |
+
speakers_vits_onnx,
|
1082 |
+
speakers_openai_tts
|
1083 |
+
]
|
1084 |
+
|
1085 |
+
|
1086 |
+
def accelerate_segments(
|
1087 |
+
result_diarize,
|
1088 |
+
max_accelerate_audio,
|
1089 |
+
valid_speakers,
|
1090 |
+
acceleration_rate_regulation=False,
|
1091 |
+
folder_output="audio2",
|
1092 |
+
):
|
1093 |
+
logger.info("Apply acceleration")
|
1094 |
+
|
1095 |
+
(
|
1096 |
+
speakers_edge,
|
1097 |
+
speakers_bark,
|
1098 |
+
speakers_vits,
|
1099 |
+
speakers_coqui,
|
1100 |
+
speakers_vits_onnx,
|
1101 |
+
speakers_openai_tts
|
1102 |
+
) = valid_speakers
|
1103 |
+
|
1104 |
+
create_directories(f"{folder_output}/audio/")
|
1105 |
+
remove_directory_contents(f"{folder_output}/audio/")
|
1106 |
+
|
1107 |
+
audio_files = []
|
1108 |
+
speakers_list = []
|
1109 |
+
|
1110 |
+
max_count_segments_idx = len(result_diarize["segments"]) - 1
|
1111 |
+
|
1112 |
+
for i, segment in tqdm(enumerate(result_diarize["segments"])):
|
1113 |
+
text = segment["text"] # noqa
|
1114 |
+
start = segment["start"]
|
1115 |
+
end = segment["end"]
|
1116 |
+
speaker = segment["speaker"]
|
1117 |
+
|
1118 |
+
# find name audio
|
1119 |
+
# if speaker in speakers_edge:
|
1120 |
+
filename = f"audio/{start}.ogg"
|
1121 |
+
# elif speaker in speakers_bark + speakers_vits + speakers_coqui + speakers_vits_onnx:
|
1122 |
+
# filename = f"audio/{start}.wav" # wav
|
1123 |
+
|
1124 |
+
# duration
|
1125 |
+
duration_true = end - start
|
1126 |
+
duration_tts = librosa.get_duration(filename=filename)
|
1127 |
+
|
1128 |
+
# Accelerate percentage
|
1129 |
+
acc_percentage = duration_tts / duration_true
|
1130 |
+
|
1131 |
+
# Smoth
|
1132 |
+
if acceleration_rate_regulation and acc_percentage >= 1.3:
|
1133 |
+
try:
|
1134 |
+
next_segment = result_diarize["segments"][
|
1135 |
+
min(max_count_segments_idx, i + 1)
|
1136 |
+
]
|
1137 |
+
next_start = next_segment["start"]
|
1138 |
+
next_speaker = next_segment["speaker"]
|
1139 |
+
duration_with_next_start = next_start - start
|
1140 |
+
|
1141 |
+
if duration_with_next_start > duration_true:
|
1142 |
+
extra_time = duration_with_next_start - duration_true
|
1143 |
+
|
1144 |
+
if speaker == next_speaker:
|
1145 |
+
# half
|
1146 |
+
smoth_duration = duration_true + (extra_time * 0.5)
|
1147 |
+
else:
|
1148 |
+
# 7/10
|
1149 |
+
smoth_duration = duration_true + (extra_time * 0.7)
|
1150 |
+
logger.debug(
|
1151 |
+
f"Base acc: {acc_percentage}, "
|
1152 |
+
f"smoth acc: {duration_tts / smoth_duration}"
|
1153 |
+
)
|
1154 |
+
acc_percentage = max(1.2, (duration_tts / smoth_duration))
|
1155 |
+
|
1156 |
+
except Exception as error:
|
1157 |
+
logger.error(str(error))
|
1158 |
+
|
1159 |
+
if acc_percentage > max_accelerate_audio:
|
1160 |
+
acc_percentage = max_accelerate_audio
|
1161 |
+
elif acc_percentage <= 1.15 and acc_percentage >= 0.8:
|
1162 |
+
acc_percentage = 1.0
|
1163 |
+
elif acc_percentage <= 0.79:
|
1164 |
+
acc_percentage = 0.8
|
1165 |
+
|
1166 |
+
# Round
|
1167 |
+
acc_percentage = round(acc_percentage + 0.0, 1)
|
1168 |
+
|
1169 |
+
# Format read if need
|
1170 |
+
if speaker in speakers_edge:
|
1171 |
+
info_enc = sf.info(filename).format
|
1172 |
+
else:
|
1173 |
+
info_enc = "OGG"
|
1174 |
+
|
1175 |
+
# Apply aceleration or opposite to the audio file in folder_output folder
|
1176 |
+
if acc_percentage == 1.0 and info_enc == "OGG":
|
1177 |
+
copy_files(filename, f"{folder_output}{os.sep}audio")
|
1178 |
+
else:
|
1179 |
+
os.system(
|
1180 |
+
f"ffmpeg -y -loglevel panic -i {filename} -filter:a atempo={acc_percentage} {folder_output}/{filename}"
|
1181 |
+
)
|
1182 |
+
|
1183 |
+
if logger.isEnabledFor(logging.DEBUG):
|
1184 |
+
duration_create = librosa.get_duration(
|
1185 |
+
filename=f"{folder_output}/{filename}"
|
1186 |
+
)
|
1187 |
+
logger.debug(
|
1188 |
+
f"acc_percen is {acc_percentage}, tts duration "
|
1189 |
+
f"is {duration_tts}, new duration is {duration_create}"
|
1190 |
+
f", for {filename}"
|
1191 |
+
)
|
1192 |
+
|
1193 |
+
audio_files.append(f"{folder_output}/{filename}")
|
1194 |
+
speaker = "TTS Speaker {:02d}".format(int(speaker[-2:]) + 1)
|
1195 |
+
speakers_list.append(speaker)
|
1196 |
+
|
1197 |
+
return audio_files, speakers_list
|
1198 |
+
|
1199 |
+
|
1200 |
+
# =====================================
|
1201 |
+
# Tone color converter
|
1202 |
+
# =====================================
|
1203 |
+
|
1204 |
+
|
1205 |
+
def se_process_audio_segments(
|
1206 |
+
source_seg, tone_color_converter, device, remove_previous_processed=True
|
1207 |
+
):
|
1208 |
+
# list wav seg
|
1209 |
+
source_audio_segs = glob.glob(f"{source_seg}/*.wav")
|
1210 |
+
if not source_audio_segs:
|
1211 |
+
raise ValueError(
|
1212 |
+
f"No audio segments found in {str(source_audio_segs)}"
|
1213 |
+
)
|
1214 |
+
|
1215 |
+
source_se_path = os.path.join(source_seg, "se.pth")
|
1216 |
+
|
1217 |
+
# if exist not create wav
|
1218 |
+
if os.path.isfile(source_se_path):
|
1219 |
+
se = torch.load(source_se_path).to(device)
|
1220 |
+
logger.debug(f"Previous created {source_se_path}")
|
1221 |
+
else:
|
1222 |
+
se = tone_color_converter.extract_se(source_audio_segs, source_se_path)
|
1223 |
+
|
1224 |
+
return se
|
1225 |
+
|
1226 |
+
|
1227 |
+
def create_wav_vc(
|
1228 |
+
valid_speakers,
|
1229 |
+
segments_base,
|
1230 |
+
audio_name,
|
1231 |
+
max_segments=10,
|
1232 |
+
target_dir="processed",
|
1233 |
+
get_vocals_dereverb=False,
|
1234 |
+
):
|
1235 |
+
# valid_speakers = list({item['speaker'] for item in segments_base})
|
1236 |
+
|
1237 |
+
# Before function delete automatic delete_previous_automatic
|
1238 |
+
output_dir = os.path.join(".", target_dir) # remove content
|
1239 |
+
# remove_directory_contents(output_dir)
|
1240 |
+
|
1241 |
+
path_source_segments = []
|
1242 |
+
path_target_segments = []
|
1243 |
+
for speaker in valid_speakers:
|
1244 |
+
filtered_speaker = [
|
1245 |
+
segment
|
1246 |
+
for segment in segments_base
|
1247 |
+
if segment["speaker"] == speaker
|
1248 |
+
]
|
1249 |
+
if len(filtered_speaker) > 4:
|
1250 |
+
filtered_speaker = filtered_speaker[1:]
|
1251 |
+
|
1252 |
+
dir_name_speaker = speaker + audio_name
|
1253 |
+
dir_name_speaker_tts = "tts" + speaker + audio_name
|
1254 |
+
dir_path_speaker = os.path.join(output_dir, dir_name_speaker)
|
1255 |
+
dir_path_speaker_tts = os.path.join(output_dir, dir_name_speaker_tts)
|
1256 |
+
create_directories([dir_path_speaker, dir_path_speaker_tts])
|
1257 |
+
|
1258 |
+
path_target_segments.append(dir_path_speaker)
|
1259 |
+
path_source_segments.append(dir_path_speaker_tts)
|
1260 |
+
|
1261 |
+
# create wav
|
1262 |
+
max_segments_count = 0
|
1263 |
+
for seg in filtered_speaker:
|
1264 |
+
duration = float(seg["end"]) - float(seg["start"])
|
1265 |
+
if duration > 3.0 and duration < 18.0:
|
1266 |
+
logger.info(
|
1267 |
+
f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {duration}, {seg["text"]}'
|
1268 |
+
)
|
1269 |
+
name_new_wav = str(seg["start"])
|
1270 |
+
|
1271 |
+
check_segment_audio_target_file = os.path.join(
|
1272 |
+
dir_path_speaker, f"{name_new_wav}.wav"
|
1273 |
+
)
|
1274 |
+
|
1275 |
+
if os.path.exists(check_segment_audio_target_file):
|
1276 |
+
logger.debug(
|
1277 |
+
"Segment vc source exists: "
|
1278 |
+
f"{check_segment_audio_target_file}"
|
1279 |
+
)
|
1280 |
+
pass
|
1281 |
+
else:
|
1282 |
+
create_wav_file_vc(
|
1283 |
+
sample_name=name_new_wav,
|
1284 |
+
audio_wav="audio.wav",
|
1285 |
+
start=(float(seg["start"]) + 1.0),
|
1286 |
+
end=(float(seg["end"]) - 1.0),
|
1287 |
+
output_final_path=dir_path_speaker,
|
1288 |
+
get_vocals_dereverb=get_vocals_dereverb,
|
1289 |
+
)
|
1290 |
+
|
1291 |
+
file_name_tts = f"audio2/audio/{str(seg['start'])}.ogg"
|
1292 |
+
# copy_files(file_name_tts, os.path.join(output_dir, dir_name_speaker_tts)
|
1293 |
+
convert_to_xtts_good_sample(
|
1294 |
+
file_name_tts, dir_path_speaker_tts
|
1295 |
+
)
|
1296 |
+
|
1297 |
+
max_segments_count += 1
|
1298 |
+
if max_segments_count == max_segments:
|
1299 |
+
break
|
1300 |
+
|
1301 |
+
if max_segments_count == 0:
|
1302 |
+
logger.info("Taking the first segment")
|
1303 |
+
seg = filtered_speaker[0]
|
1304 |
+
logger.info(
|
1305 |
+
f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {seg["text"]}'
|
1306 |
+
)
|
1307 |
+
max_duration = float(seg["end"]) - float(seg["start"])
|
1308 |
+
max_duration = max(1.0, min(max_duration, 18.0))
|
1309 |
+
|
1310 |
+
name_new_wav = str(seg["start"])
|
1311 |
+
create_wav_file_vc(
|
1312 |
+
sample_name=name_new_wav,
|
1313 |
+
audio_wav="audio.wav",
|
1314 |
+
start=(float(seg["start"])),
|
1315 |
+
end=(float(seg["start"]) + max_duration),
|
1316 |
+
output_final_path=dir_path_speaker,
|
1317 |
+
get_vocals_dereverb=get_vocals_dereverb,
|
1318 |
+
)
|
1319 |
+
|
1320 |
+
file_name_tts = f"audio2/audio/{str(seg['start'])}.ogg"
|
1321 |
+
# copy_files(file_name_tts, os.path.join(output_dir, dir_name_speaker_tts)
|
1322 |
+
convert_to_xtts_good_sample(file_name_tts, dir_path_speaker_tts)
|
1323 |
+
|
1324 |
+
logger.debug(f"Base: {str(path_source_segments)}")
|
1325 |
+
logger.debug(f"Target: {str(path_target_segments)}")
|
1326 |
+
|
1327 |
+
return path_source_segments, path_target_segments
|
1328 |
+
|
1329 |
+
|
1330 |
+
def toneconverter_openvoice(
|
1331 |
+
result_diarize,
|
1332 |
+
preprocessor_max_segments,
|
1333 |
+
remove_previous_process=True,
|
1334 |
+
get_vocals_dereverb=False,
|
1335 |
+
model="openvoice",
|
1336 |
+
):
|
1337 |
+
audio_path = "audio.wav"
|
1338 |
+
# se_path = "se.pth"
|
1339 |
+
target_dir = "processed"
|
1340 |
+
create_directories(target_dir)
|
1341 |
+
|
1342 |
+
from openvoice import se_extractor
|
1343 |
+
from openvoice.api import ToneColorConverter
|
1344 |
+
|
1345 |
+
audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{se_extractor.hash_numpy_array(audio_path)}"
|
1346 |
+
# se_path = os.path.join(target_dir, audio_name, 'se.pth')
|
1347 |
+
|
1348 |
+
# create wav seg original and target
|
1349 |
+
|
1350 |
+
valid_speakers = list(
|
1351 |
+
{item["speaker"] for item in result_diarize["segments"]}
|
1352 |
+
)
|
1353 |
+
|
1354 |
+
logger.info("Openvoice preprocessor...")
|
1355 |
+
|
1356 |
+
if remove_previous_process:
|
1357 |
+
remove_directory_contents(target_dir)
|
1358 |
+
|
1359 |
+
path_source_segments, path_target_segments = create_wav_vc(
|
1360 |
+
valid_speakers,
|
1361 |
+
result_diarize["segments"],
|
1362 |
+
audio_name,
|
1363 |
+
max_segments=preprocessor_max_segments,
|
1364 |
+
get_vocals_dereverb=get_vocals_dereverb,
|
1365 |
+
)
|
1366 |
+
|
1367 |
+
logger.info("Openvoice loading model...")
|
1368 |
+
model_path_openvoice = "./OPENVOICE_MODELS"
|
1369 |
+
url_model_openvoice = "https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter"
|
1370 |
+
|
1371 |
+
if "v2" in model:
|
1372 |
+
model_path = os.path.join(model_path_openvoice, "v2")
|
1373 |
+
url_model_openvoice = url_model_openvoice.replace(
|
1374 |
+
"OpenVoice", "OpenVoiceV2"
|
1375 |
+
).replace("checkpoints/", "")
|
1376 |
+
else:
|
1377 |
+
model_path = os.path.join(model_path_openvoice, "v1")
|
1378 |
+
create_directories(model_path)
|
1379 |
+
|
1380 |
+
config_url = f"{url_model_openvoice}/config.json"
|
1381 |
+
checkpoint_url = f"{url_model_openvoice}/checkpoint.pth"
|
1382 |
+
|
1383 |
+
config_path = download_manager(url=config_url, path=model_path)
|
1384 |
+
checkpoint_path = download_manager(
|
1385 |
+
url=checkpoint_url, path=model_path
|
1386 |
+
)
|
1387 |
+
|
1388 |
+
device = os.environ.get("SONITR_DEVICE")
|
1389 |
+
tone_color_converter = ToneColorConverter(config_path, device=device)
|
1390 |
+
tone_color_converter.load_ckpt(checkpoint_path)
|
1391 |
+
|
1392 |
+
logger.info("Openvoice tone color converter:")
|
1393 |
+
global_progress_bar = tqdm(total=len(result_diarize["segments"]), desc="Progress")
|
1394 |
+
|
1395 |
+
for source_seg, target_seg, speaker in zip(
|
1396 |
+
path_source_segments, path_target_segments, valid_speakers
|
1397 |
+
):
|
1398 |
+
# source_se_path = os.path.join(source_seg, 'se.pth')
|
1399 |
+
source_se = se_process_audio_segments(source_seg, tone_color_converter, device)
|
1400 |
+
# target_se_path = os.path.join(target_seg, 'se.pth')
|
1401 |
+
target_se = se_process_audio_segments(target_seg, tone_color_converter, device)
|
1402 |
+
|
1403 |
+
# Iterate throw segments
|
1404 |
+
encode_message = "@MyShell"
|
1405 |
+
filtered_speaker = [
|
1406 |
+
segment
|
1407 |
+
for segment in result_diarize["segments"]
|
1408 |
+
if segment["speaker"] == speaker
|
1409 |
+
]
|
1410 |
+
for seg in filtered_speaker:
|
1411 |
+
src_path = (
|
1412 |
+
save_path
|
1413 |
+
) = f"audio2/audio/{str(seg['start'])}.ogg" # overwrite
|
1414 |
+
logger.debug(f"{src_path}")
|
1415 |
+
|
1416 |
+
tone_color_converter.convert(
|
1417 |
+
audio_src_path=src_path,
|
1418 |
+
src_se=source_se,
|
1419 |
+
tgt_se=target_se,
|
1420 |
+
output_path=save_path,
|
1421 |
+
message=encode_message,
|
1422 |
+
)
|
1423 |
+
|
1424 |
+
global_progress_bar.update(1)
|
1425 |
+
|
1426 |
+
global_progress_bar.close()
|
1427 |
+
|
1428 |
+
try:
|
1429 |
+
del tone_color_converter
|
1430 |
+
gc.collect()
|
1431 |
+
torch.cuda.empty_cache()
|
1432 |
+
except Exception as error:
|
1433 |
+
logger.error(str(error))
|
1434 |
+
gc.collect()
|
1435 |
+
torch.cuda.empty_cache()
|
1436 |
+
|
1437 |
+
|
1438 |
+
def toneconverter_freevc(
|
1439 |
+
result_diarize,
|
1440 |
+
remove_previous_process=True,
|
1441 |
+
get_vocals_dereverb=False,
|
1442 |
+
):
|
1443 |
+
audio_path = "audio.wav"
|
1444 |
+
target_dir = "processed"
|
1445 |
+
create_directories(target_dir)
|
1446 |
+
|
1447 |
+
from openvoice import se_extractor
|
1448 |
+
|
1449 |
+
audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{se_extractor.hash_numpy_array(audio_path)}"
|
1450 |
+
|
1451 |
+
# create wav seg; original is target and dubbing is source
|
1452 |
+
valid_speakers = list(
|
1453 |
+
{item["speaker"] for item in result_diarize["segments"]}
|
1454 |
+
)
|
1455 |
+
|
1456 |
+
logger.info("FreeVC preprocessor...")
|
1457 |
+
|
1458 |
+
if remove_previous_process:
|
1459 |
+
remove_directory_contents(target_dir)
|
1460 |
+
|
1461 |
+
path_source_segments, path_target_segments = create_wav_vc(
|
1462 |
+
valid_speakers,
|
1463 |
+
result_diarize["segments"],
|
1464 |
+
audio_name,
|
1465 |
+
max_segments=1,
|
1466 |
+
get_vocals_dereverb=get_vocals_dereverb,
|
1467 |
+
)
|
1468 |
+
|
1469 |
+
logger.info("FreeVC loading model...")
|
1470 |
+
device_id = os.environ.get("SONITR_DEVICE")
|
1471 |
+
device = None if device_id == "cpu" else device_id
|
1472 |
+
try:
|
1473 |
+
from TTS.api import TTS
|
1474 |
+
tts = TTS(
|
1475 |
+
model_name="voice_conversion_models/multilingual/vctk/freevc24",
|
1476 |
+
progress_bar=False
|
1477 |
+
).to(device)
|
1478 |
+
except Exception as error:
|
1479 |
+
logger.error(str(error))
|
1480 |
+
logger.error("Error loading the FreeVC model.")
|
1481 |
+
return
|
1482 |
+
|
1483 |
+
logger.info("FreeVC process:")
|
1484 |
+
global_progress_bar = tqdm(total=len(result_diarize["segments"]), desc="Progress")
|
1485 |
+
|
1486 |
+
for source_seg, target_seg, speaker in zip(
|
1487 |
+
path_source_segments, path_target_segments, valid_speakers
|
1488 |
+
):
|
1489 |
+
|
1490 |
+
filtered_speaker = [
|
1491 |
+
segment
|
1492 |
+
for segment in result_diarize["segments"]
|
1493 |
+
if segment["speaker"] == speaker
|
1494 |
+
]
|
1495 |
+
|
1496 |
+
files_and_directories = os.listdir(target_seg)
|
1497 |
+
wav_files = [file for file in files_and_directories if file.endswith(".wav")]
|
1498 |
+
original_wav_audio_segment = os.path.join(target_seg, wav_files[0])
|
1499 |
+
|
1500 |
+
for seg in filtered_speaker:
|
1501 |
+
|
1502 |
+
src_path = (
|
1503 |
+
save_path
|
1504 |
+
) = f"audio2/audio/{str(seg['start'])}.ogg" # overwrite
|
1505 |
+
logger.debug(f"{src_path} - {original_wav_audio_segment}")
|
1506 |
+
|
1507 |
+
wav = tts.voice_conversion(
|
1508 |
+
source_wav=src_path,
|
1509 |
+
target_wav=original_wav_audio_segment,
|
1510 |
+
)
|
1511 |
+
|
1512 |
+
sf.write(
|
1513 |
+
file=save_path,
|
1514 |
+
samplerate=tts.voice_converter.vc_config.audio.output_sample_rate,
|
1515 |
+
data=wav,
|
1516 |
+
format="ogg",
|
1517 |
+
subtype="vorbis",
|
1518 |
+
)
|
1519 |
+
|
1520 |
+
global_progress_bar.update(1)
|
1521 |
+
|
1522 |
+
global_progress_bar.close()
|
1523 |
+
|
1524 |
+
try:
|
1525 |
+
del tts
|
1526 |
+
gc.collect()
|
1527 |
+
torch.cuda.empty_cache()
|
1528 |
+
except Exception as error:
|
1529 |
+
logger.error(str(error))
|
1530 |
+
gc.collect()
|
1531 |
+
torch.cuda.empty_cache()
|
1532 |
+
|
1533 |
+
|
1534 |
+
def toneconverter(
|
1535 |
+
result_diarize,
|
1536 |
+
preprocessor_max_segments,
|
1537 |
+
remove_previous_process=True,
|
1538 |
+
get_vocals_dereverb=False,
|
1539 |
+
method_vc="freevc"
|
1540 |
+
):
|
1541 |
+
|
1542 |
+
if method_vc == "freevc":
|
1543 |
+
if preprocessor_max_segments > 1:
|
1544 |
+
logger.info("FreeVC only uses one segment.")
|
1545 |
+
return toneconverter_freevc(
|
1546 |
+
result_diarize,
|
1547 |
+
remove_previous_process=remove_previous_process,
|
1548 |
+
get_vocals_dereverb=get_vocals_dereverb,
|
1549 |
+
)
|
1550 |
+
elif "openvoice" in method_vc:
|
1551 |
+
return toneconverter_openvoice(
|
1552 |
+
result_diarize,
|
1553 |
+
preprocessor_max_segments,
|
1554 |
+
remove_previous_process=remove_previous_process,
|
1555 |
+
get_vocals_dereverb=get_vocals_dereverb,
|
1556 |
+
model=method_vc,
|
1557 |
+
)
|
1558 |
+
|
1559 |
+
|
1560 |
+
if __name__ == "__main__":
|
1561 |
+
from segments import result_diarize
|
1562 |
+
|
1563 |
+
audio_segmentation_to_voice(
|
1564 |
+
result_diarize,
|
1565 |
+
TRANSLATE_AUDIO_TO="en",
|
1566 |
+
max_accelerate_audio=2.1,
|
1567 |
+
is_gui=True,
|
1568 |
+
tts_voice00="en-facebook-mms VITS",
|
1569 |
+
tts_voice01="en-CA-ClaraNeural-Female",
|
1570 |
+
tts_voice02="en-GB-ThomasNeural-Male",
|
1571 |
+
tts_voice03="en-GB-SoniaNeural-Female",
|
1572 |
+
tts_voice04="en-NZ-MitchellNeural-Male",
|
1573 |
+
tts_voice05="en-GB-MaisieNeural-Female",
|
1574 |
+
)
|
soni_translate/translate_segments.py
ADDED
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
from deep_translator import GoogleTranslator
|
3 |
+
from itertools import chain
|
4 |
+
import copy
|
5 |
+
from .language_configuration import fix_code_language, INVERTED_LANGUAGES
|
6 |
+
from .logging_setup import logger
|
7 |
+
import re
|
8 |
+
import json
|
9 |
+
import time
|
10 |
+
|
11 |
+
TRANSLATION_PROCESS_OPTIONS = [
|
12 |
+
"google_translator_batch",
|
13 |
+
"google_translator",
|
14 |
+
"gpt-3.5-turbo-0125_batch",
|
15 |
+
"gpt-3.5-turbo-0125",
|
16 |
+
"gpt-4-turbo-preview_batch",
|
17 |
+
"gpt-4-turbo-preview",
|
18 |
+
"disable_translation",
|
19 |
+
]
|
20 |
+
DOCS_TRANSLATION_PROCESS_OPTIONS = [
|
21 |
+
"google_translator",
|
22 |
+
"gpt-3.5-turbo-0125",
|
23 |
+
"gpt-4-turbo-preview",
|
24 |
+
"disable_translation",
|
25 |
+
]
|
26 |
+
|
27 |
+
|
28 |
+
def translate_iterative(segments, target, source=None):
|
29 |
+
"""
|
30 |
+
Translate text segments individually to the specified language.
|
31 |
+
|
32 |
+
Parameters:
|
33 |
+
- segments (list): A list of dictionaries with 'text' as a key for
|
34 |
+
segment text.
|
35 |
+
- target (str): Target language code.
|
36 |
+
- source (str, optional): Source language code. Defaults to None.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
- list: Translated text segments in the target language.
|
40 |
+
|
41 |
+
Notes:
|
42 |
+
- Translates each segment using Google Translate.
|
43 |
+
|
44 |
+
Example:
|
45 |
+
segments = [{'text': 'first segment.'}, {'text': 'second segment.'}]
|
46 |
+
translated_segments = translate_iterative(segments, 'es')
|
47 |
+
"""
|
48 |
+
|
49 |
+
segments_ = copy.deepcopy(segments)
|
50 |
+
|
51 |
+
if (
|
52 |
+
not source
|
53 |
+
):
|
54 |
+
logger.debug("No source language")
|
55 |
+
source = "auto"
|
56 |
+
|
57 |
+
translator = GoogleTranslator(source=source, target=target)
|
58 |
+
|
59 |
+
for line in tqdm(range(len(segments_))):
|
60 |
+
text = segments_[line]["text"]
|
61 |
+
translated_line = translator.translate(text.strip())
|
62 |
+
segments_[line]["text"] = translated_line
|
63 |
+
|
64 |
+
return segments_
|
65 |
+
|
66 |
+
|
67 |
+
def verify_translate(
|
68 |
+
segments,
|
69 |
+
segments_copy,
|
70 |
+
translated_lines,
|
71 |
+
target,
|
72 |
+
source
|
73 |
+
):
|
74 |
+
"""
|
75 |
+
Verify integrity and translate segments if lengths match, otherwise
|
76 |
+
switch to iterative translation.
|
77 |
+
"""
|
78 |
+
if len(segments) == len(translated_lines):
|
79 |
+
for line in range(len(segments_copy)):
|
80 |
+
logger.debug(
|
81 |
+
f"{segments_copy[line]['text']} >> "
|
82 |
+
f"{translated_lines[line].strip()}"
|
83 |
+
)
|
84 |
+
segments_copy[line]["text"] = translated_lines[
|
85 |
+
line].replace("\t", "").replace("\n", "").strip()
|
86 |
+
return segments_copy
|
87 |
+
else:
|
88 |
+
logger.error(
|
89 |
+
"The translation failed, switching to google_translate iterative. "
|
90 |
+
f"{len(segments), len(translated_lines)}"
|
91 |
+
)
|
92 |
+
return translate_iterative(segments, target, source)
|
93 |
+
|
94 |
+
|
95 |
+
def translate_batch(segments, target, chunk_size=2000, source=None):
|
96 |
+
"""
|
97 |
+
Translate a batch of text segments into the specified language in chunks,
|
98 |
+
respecting the character limit.
|
99 |
+
|
100 |
+
Parameters:
|
101 |
+
- segments (list): List of dictionaries with 'text' as a key for segment
|
102 |
+
text.
|
103 |
+
- target (str): Target language code.
|
104 |
+
- chunk_size (int, optional): Maximum character limit for each translation
|
105 |
+
chunk (default is 2000; max 5000).
|
106 |
+
- source (str, optional): Source language code. Defaults to None.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
- list: Translated text segments in the target language.
|
110 |
+
|
111 |
+
Notes:
|
112 |
+
- Splits input segments into chunks respecting the character limit for
|
113 |
+
translation.
|
114 |
+
- Translates the chunks using Google Translate.
|
115 |
+
- If chunked translation fails, switches to iterative translation using
|
116 |
+
`translate_iterative()`.
|
117 |
+
|
118 |
+
Example:
|
119 |
+
segments = [{'text': 'first segment.'}, {'text': 'second segment.'}]
|
120 |
+
translated = translate_batch(segments, 'es', chunk_size=4000, source='en')
|
121 |
+
"""
|
122 |
+
|
123 |
+
segments_copy = copy.deepcopy(segments)
|
124 |
+
|
125 |
+
if (
|
126 |
+
not source
|
127 |
+
):
|
128 |
+
logger.debug("No source language")
|
129 |
+
source = "auto"
|
130 |
+
|
131 |
+
# Get text
|
132 |
+
text_lines = []
|
133 |
+
for line in range(len(segments_copy)):
|
134 |
+
text = segments_copy[line]["text"].strip()
|
135 |
+
text_lines.append(text)
|
136 |
+
|
137 |
+
# chunk limit
|
138 |
+
text_merge = []
|
139 |
+
actual_chunk = ""
|
140 |
+
global_text_list = []
|
141 |
+
actual_text_list = []
|
142 |
+
for one_line in text_lines:
|
143 |
+
one_line = " " if not one_line else one_line
|
144 |
+
if (len(actual_chunk) + len(one_line)) <= chunk_size:
|
145 |
+
if actual_chunk:
|
146 |
+
actual_chunk += " ||||| "
|
147 |
+
actual_chunk += one_line
|
148 |
+
actual_text_list.append(one_line)
|
149 |
+
else:
|
150 |
+
text_merge.append(actual_chunk)
|
151 |
+
actual_chunk = one_line
|
152 |
+
global_text_list.append(actual_text_list)
|
153 |
+
actual_text_list = [one_line]
|
154 |
+
if actual_chunk:
|
155 |
+
text_merge.append(actual_chunk)
|
156 |
+
global_text_list.append(actual_text_list)
|
157 |
+
|
158 |
+
# translate chunks
|
159 |
+
progress_bar = tqdm(total=len(segments), desc="Translating")
|
160 |
+
translator = GoogleTranslator(source=source, target=target)
|
161 |
+
split_list = []
|
162 |
+
try:
|
163 |
+
for text, text_iterable in zip(text_merge, global_text_list):
|
164 |
+
translated_line = translator.translate(text.strip())
|
165 |
+
split_text = translated_line.split("|||||")
|
166 |
+
if len(split_text) == len(text_iterable):
|
167 |
+
progress_bar.update(len(split_text))
|
168 |
+
else:
|
169 |
+
logger.debug(
|
170 |
+
"Chunk fixing iteratively. Len chunk: "
|
171 |
+
f"{len(split_text)}, expected: {len(text_iterable)}"
|
172 |
+
)
|
173 |
+
split_text = []
|
174 |
+
for txt_iter in text_iterable:
|
175 |
+
translated_txt = translator.translate(txt_iter.strip())
|
176 |
+
split_text.append(translated_txt)
|
177 |
+
progress_bar.update(1)
|
178 |
+
split_list.append(split_text)
|
179 |
+
progress_bar.close()
|
180 |
+
except Exception as error:
|
181 |
+
progress_bar.close()
|
182 |
+
logger.error(str(error))
|
183 |
+
logger.warning(
|
184 |
+
"The translation in chunks failed, switching to iterative."
|
185 |
+
" Related: too many request"
|
186 |
+
) # use proxy or less chunk size
|
187 |
+
return translate_iterative(segments, target, source)
|
188 |
+
|
189 |
+
# un chunk
|
190 |
+
translated_lines = list(chain.from_iterable(split_list))
|
191 |
+
|
192 |
+
return verify_translate(
|
193 |
+
segments, segments_copy, translated_lines, target, source
|
194 |
+
)
|
195 |
+
|
196 |
+
|
197 |
+
def call_gpt_translate(
|
198 |
+
client,
|
199 |
+
model,
|
200 |
+
system_prompt,
|
201 |
+
user_prompt,
|
202 |
+
original_text=None,
|
203 |
+
batch_lines=None,
|
204 |
+
):
|
205 |
+
|
206 |
+
# https://platform.openai.com/docs/guides/text-generation/json-mode
|
207 |
+
response = client.chat.completions.create(
|
208 |
+
model=model,
|
209 |
+
response_format={"type": "json_object"},
|
210 |
+
messages=[
|
211 |
+
{"role": "system", "content": system_prompt},
|
212 |
+
{"role": "user", "content": user_prompt}
|
213 |
+
]
|
214 |
+
)
|
215 |
+
result = response.choices[0].message.content
|
216 |
+
logger.debug(f"Result: {str(result)}")
|
217 |
+
|
218 |
+
try:
|
219 |
+
translation = json.loads(result)
|
220 |
+
except Exception as error:
|
221 |
+
match_result = re.search(r'\{.*?\}', result)
|
222 |
+
if match_result:
|
223 |
+
logger.error(str(error))
|
224 |
+
json_str = match_result.group(0)
|
225 |
+
translation = json.loads(json_str)
|
226 |
+
else:
|
227 |
+
raise error
|
228 |
+
|
229 |
+
# Get valid data
|
230 |
+
if batch_lines:
|
231 |
+
for conversation in translation.values():
|
232 |
+
if isinstance(conversation, dict):
|
233 |
+
conversation = list(conversation.values())[0]
|
234 |
+
if (
|
235 |
+
list(
|
236 |
+
original_text["conversation"][0].values()
|
237 |
+
)[0].strip() ==
|
238 |
+
list(conversation[0].values())[0].strip()
|
239 |
+
):
|
240 |
+
continue
|
241 |
+
if len(conversation) == batch_lines:
|
242 |
+
break
|
243 |
+
|
244 |
+
fix_conversation_length = []
|
245 |
+
for line in conversation:
|
246 |
+
for speaker_code, text_tr in line.items():
|
247 |
+
fix_conversation_length.append({speaker_code: text_tr})
|
248 |
+
|
249 |
+
logger.debug(f"Data batch: {str(fix_conversation_length)}")
|
250 |
+
logger.debug(
|
251 |
+
f"Lines Received: {len(fix_conversation_length)},"
|
252 |
+
f" expected: {batch_lines}"
|
253 |
+
)
|
254 |
+
|
255 |
+
return fix_conversation_length
|
256 |
+
|
257 |
+
else:
|
258 |
+
if isinstance(translation, dict):
|
259 |
+
translation = list(translation.values())[0]
|
260 |
+
if isinstance(translation, list):
|
261 |
+
translation = translation[0]
|
262 |
+
if isinstance(translation, set):
|
263 |
+
translation = list(translation)[0]
|
264 |
+
if not isinstance(translation, str):
|
265 |
+
raise ValueError(f"No valid response received: {str(translation)}")
|
266 |
+
|
267 |
+
return translation
|
268 |
+
|
269 |
+
|
270 |
+
def gpt_sequential(segments, model, target, source=None):
|
271 |
+
from openai import OpenAI
|
272 |
+
|
273 |
+
translated_segments = copy.deepcopy(segments)
|
274 |
+
|
275 |
+
client = OpenAI()
|
276 |
+
progress_bar = tqdm(total=len(segments), desc="Translating")
|
277 |
+
|
278 |
+
lang_tg = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[target]).strip()
|
279 |
+
lang_sc = ""
|
280 |
+
if source:
|
281 |
+
lang_sc = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[source]).strip()
|
282 |
+
|
283 |
+
fixed_target = fix_code_language(target)
|
284 |
+
fixed_source = fix_code_language(source) if source else "auto"
|
285 |
+
|
286 |
+
system_prompt = "Machine translation designed to output the translated_text JSON."
|
287 |
+
|
288 |
+
for i, line in enumerate(translated_segments):
|
289 |
+
text = line["text"].strip()
|
290 |
+
start = line["start"]
|
291 |
+
user_prompt = f"Translate the following {lang_sc} text into {lang_tg}, write the fully translated text and nothing more:\n{text}"
|
292 |
+
|
293 |
+
time.sleep(0.5)
|
294 |
+
|
295 |
+
try:
|
296 |
+
translated_text = call_gpt_translate(
|
297 |
+
client,
|
298 |
+
model,
|
299 |
+
system_prompt,
|
300 |
+
user_prompt,
|
301 |
+
)
|
302 |
+
|
303 |
+
except Exception as error:
|
304 |
+
logger.error(
|
305 |
+
f"{str(error)} >> The text of segment {start} "
|
306 |
+
"is being corrected with Google Translate"
|
307 |
+
)
|
308 |
+
translator = GoogleTranslator(
|
309 |
+
source=fixed_source, target=fixed_target
|
310 |
+
)
|
311 |
+
translated_text = translator.translate(text.strip())
|
312 |
+
|
313 |
+
translated_segments[i]["text"] = translated_text.strip()
|
314 |
+
progress_bar.update(1)
|
315 |
+
|
316 |
+
progress_bar.close()
|
317 |
+
|
318 |
+
return translated_segments
|
319 |
+
|
320 |
+
|
321 |
+
def gpt_batch(segments, model, target, token_batch_limit=900, source=None):
|
322 |
+
from openai import OpenAI
|
323 |
+
import tiktoken
|
324 |
+
|
325 |
+
token_batch_limit = max(100, (token_batch_limit - 40) // 2)
|
326 |
+
progress_bar = tqdm(total=len(segments), desc="Translating")
|
327 |
+
segments_copy = copy.deepcopy(segments)
|
328 |
+
encoding = tiktoken.get_encoding("cl100k_base")
|
329 |
+
client = OpenAI()
|
330 |
+
|
331 |
+
lang_tg = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[target]).strip()
|
332 |
+
lang_sc = ""
|
333 |
+
if source:
|
334 |
+
lang_sc = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[source]).strip()
|
335 |
+
|
336 |
+
fixed_target = fix_code_language(target)
|
337 |
+
fixed_source = fix_code_language(source) if source else "auto"
|
338 |
+
|
339 |
+
name_speaker = "ABCDEFGHIJKL"
|
340 |
+
|
341 |
+
translated_lines = []
|
342 |
+
text_data_dict = []
|
343 |
+
num_tokens = 0
|
344 |
+
count_sk = {char: 0 for char in "ABCDEFGHIJKL"}
|
345 |
+
|
346 |
+
for i, line in enumerate(segments_copy):
|
347 |
+
text = line["text"]
|
348 |
+
speaker = line["speaker"]
|
349 |
+
last_start = line["start"]
|
350 |
+
# text_data_dict.append({str(int(speaker[-1])+1): text})
|
351 |
+
index_sk = int(speaker[-2:])
|
352 |
+
character_sk = name_speaker[index_sk]
|
353 |
+
count_sk[character_sk] += 1
|
354 |
+
code_sk = character_sk+str(count_sk[character_sk])
|
355 |
+
text_data_dict.append({code_sk: text})
|
356 |
+
num_tokens += len(encoding.encode(text)) + 7
|
357 |
+
if num_tokens >= token_batch_limit or i == len(segments_copy)-1:
|
358 |
+
try:
|
359 |
+
batch_lines = len(text_data_dict)
|
360 |
+
batch_conversation = {"conversation": copy.deepcopy(text_data_dict)}
|
361 |
+
# Reset vars
|
362 |
+
num_tokens = 0
|
363 |
+
text_data_dict = []
|
364 |
+
count_sk = {char: 0 for char in "ABCDEFGHIJKL"}
|
365 |
+
# Process translation
|
366 |
+
# https://arxiv.org/pdf/2309.03409.pdf
|
367 |
+
system_prompt = f"Machine translation designed to output the translated_conversation key JSON containing a list of {batch_lines} items."
|
368 |
+
user_prompt = f"Translate each of the following text values in conversation{' from' if lang_sc else ''} {lang_sc} to {lang_tg}:\n{batch_conversation}"
|
369 |
+
logger.debug(f"Prompt: {str(user_prompt)}")
|
370 |
+
|
371 |
+
conversation = call_gpt_translate(
|
372 |
+
client,
|
373 |
+
model,
|
374 |
+
system_prompt,
|
375 |
+
user_prompt,
|
376 |
+
original_text=batch_conversation,
|
377 |
+
batch_lines=batch_lines,
|
378 |
+
)
|
379 |
+
|
380 |
+
if len(conversation) < batch_lines:
|
381 |
+
raise ValueError(
|
382 |
+
"Incomplete result received. Batch lines: "
|
383 |
+
f"{len(conversation)}, expected: {batch_lines}"
|
384 |
+
)
|
385 |
+
|
386 |
+
for i, translated_text in enumerate(conversation):
|
387 |
+
if i+1 > batch_lines:
|
388 |
+
break
|
389 |
+
translated_lines.append(list(translated_text.values())[0])
|
390 |
+
|
391 |
+
progress_bar.update(batch_lines)
|
392 |
+
|
393 |
+
except Exception as error:
|
394 |
+
logger.error(str(error))
|
395 |
+
|
396 |
+
first_start = segments_copy[max(0, i-(batch_lines-1))]["start"]
|
397 |
+
logger.warning(
|
398 |
+
f"The batch from {first_start} to {last_start} "
|
399 |
+
"failed, is being corrected with Google Translate"
|
400 |
+
)
|
401 |
+
|
402 |
+
translator = GoogleTranslator(
|
403 |
+
source=fixed_source,
|
404 |
+
target=fixed_target
|
405 |
+
)
|
406 |
+
|
407 |
+
for txt_source in batch_conversation["conversation"]:
|
408 |
+
translated_txt = translator.translate(
|
409 |
+
list(txt_source.values())[0].strip()
|
410 |
+
)
|
411 |
+
translated_lines.append(translated_txt.strip())
|
412 |
+
progress_bar.update(1)
|
413 |
+
|
414 |
+
progress_bar.close()
|
415 |
+
|
416 |
+
return verify_translate(
|
417 |
+
segments, segments_copy, translated_lines, fixed_target, fixed_source
|
418 |
+
)
|
419 |
+
|
420 |
+
|
421 |
+
def translate_text(
|
422 |
+
segments,
|
423 |
+
target,
|
424 |
+
translation_process="google_translator_batch",
|
425 |
+
chunk_size=4500,
|
426 |
+
source=None,
|
427 |
+
token_batch_limit=1000,
|
428 |
+
):
|
429 |
+
"""Translates text segments using a specified process."""
|
430 |
+
match translation_process:
|
431 |
+
case "google_translator_batch":
|
432 |
+
return translate_batch(
|
433 |
+
segments,
|
434 |
+
fix_code_language(target),
|
435 |
+
chunk_size,
|
436 |
+
fix_code_language(source)
|
437 |
+
)
|
438 |
+
case "google_translator":
|
439 |
+
return translate_iterative(
|
440 |
+
segments,
|
441 |
+
fix_code_language(target),
|
442 |
+
fix_code_language(source)
|
443 |
+
)
|
444 |
+
case model if model in ["gpt-3.5-turbo-0125", "gpt-4-turbo-preview"]:
|
445 |
+
return gpt_sequential(segments, model, target, source)
|
446 |
+
case model if model in ["gpt-3.5-turbo-0125_batch", "gpt-4-turbo-preview_batch",]:
|
447 |
+
return gpt_batch(
|
448 |
+
segments,
|
449 |
+
translation_process.replace("_batch", ""),
|
450 |
+
target,
|
451 |
+
token_batch_limit,
|
452 |
+
source
|
453 |
+
)
|
454 |
+
case "disable_translation":
|
455 |
+
return segments
|
456 |
+
case _:
|
457 |
+
raise ValueError("No valid translation process")
|
soni_translate/utils.py
ADDED
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, zipfile, rarfile, shutil, subprocess, shlex, sys # noqa
|
2 |
+
from .logging_setup import logger
|
3 |
+
from urllib.parse import urlparse
|
4 |
+
from IPython.utils import capture
|
5 |
+
import re
|
6 |
+
|
7 |
+
VIDEO_EXTENSIONS = [
|
8 |
+
".mp4",
|
9 |
+
".avi",
|
10 |
+
".mov",
|
11 |
+
".mkv",
|
12 |
+
".wmv",
|
13 |
+
".flv",
|
14 |
+
".webm",
|
15 |
+
".m4v",
|
16 |
+
".mpeg",
|
17 |
+
".mpg",
|
18 |
+
".3gp"
|
19 |
+
]
|
20 |
+
|
21 |
+
AUDIO_EXTENSIONS = [
|
22 |
+
".mp3",
|
23 |
+
".wav",
|
24 |
+
".aiff",
|
25 |
+
".aif",
|
26 |
+
".flac",
|
27 |
+
".aac",
|
28 |
+
".ogg",
|
29 |
+
".wma",
|
30 |
+
".m4a",
|
31 |
+
".alac",
|
32 |
+
".pcm",
|
33 |
+
".opus",
|
34 |
+
".ape",
|
35 |
+
".amr",
|
36 |
+
".ac3",
|
37 |
+
".vox",
|
38 |
+
".caf"
|
39 |
+
]
|
40 |
+
|
41 |
+
SUBTITLE_EXTENSIONS = [
|
42 |
+
".srt",
|
43 |
+
".vtt",
|
44 |
+
".ass"
|
45 |
+
]
|
46 |
+
|
47 |
+
|
48 |
+
def run_command(command):
|
49 |
+
logger.debug(command)
|
50 |
+
if isinstance(command, str):
|
51 |
+
command = shlex.split(command)
|
52 |
+
|
53 |
+
sub_params = {
|
54 |
+
"stdout": subprocess.PIPE,
|
55 |
+
"stderr": subprocess.PIPE,
|
56 |
+
"creationflags": subprocess.CREATE_NO_WINDOW
|
57 |
+
if sys.platform == "win32"
|
58 |
+
else 0,
|
59 |
+
}
|
60 |
+
process_command = subprocess.Popen(command, **sub_params)
|
61 |
+
output, errors = process_command.communicate()
|
62 |
+
if (
|
63 |
+
process_command.returncode != 0
|
64 |
+
): # or not os.path.exists(mono_path) or os.path.getsize(mono_path) == 0:
|
65 |
+
logger.error("Error comnand")
|
66 |
+
raise Exception(errors.decode())
|
67 |
+
|
68 |
+
|
69 |
+
def print_tree_directory(root_dir, indent=""):
|
70 |
+
if not os.path.exists(root_dir):
|
71 |
+
logger.error(f"{indent} Invalid directory or file: {root_dir}")
|
72 |
+
return
|
73 |
+
|
74 |
+
items = os.listdir(root_dir)
|
75 |
+
|
76 |
+
for index, item in enumerate(sorted(items)):
|
77 |
+
item_path = os.path.join(root_dir, item)
|
78 |
+
is_last_item = index == len(items) - 1
|
79 |
+
|
80 |
+
if os.path.isfile(item_path) and item_path.endswith(".zip"):
|
81 |
+
with zipfile.ZipFile(item_path, "r") as zip_file:
|
82 |
+
print(
|
83 |
+
f"{indent}{'└──' if is_last_item else '├──'} {item} (zip file)"
|
84 |
+
)
|
85 |
+
zip_contents = zip_file.namelist()
|
86 |
+
for zip_item in sorted(zip_contents):
|
87 |
+
print(
|
88 |
+
f"{indent}{' ' if is_last_item else '│ '}{zip_item}"
|
89 |
+
)
|
90 |
+
else:
|
91 |
+
print(f"{indent}{'└──' if is_last_item else '├──'} {item}")
|
92 |
+
|
93 |
+
if os.path.isdir(item_path):
|
94 |
+
new_indent = indent + (" " if is_last_item else "│ ")
|
95 |
+
print_tree_directory(item_path, new_indent)
|
96 |
+
|
97 |
+
|
98 |
+
def upload_model_list():
|
99 |
+
weight_root = "weights"
|
100 |
+
models = []
|
101 |
+
for name in os.listdir(weight_root):
|
102 |
+
if name.endswith(".pth"):
|
103 |
+
models.append("weights/" + name)
|
104 |
+
if models:
|
105 |
+
logger.debug(models)
|
106 |
+
|
107 |
+
index_root = "logs"
|
108 |
+
index_paths = [None]
|
109 |
+
for name in os.listdir(index_root):
|
110 |
+
if name.endswith(".index"):
|
111 |
+
index_paths.append("logs/" + name)
|
112 |
+
if index_paths:
|
113 |
+
logger.debug(index_paths)
|
114 |
+
|
115 |
+
return models, index_paths
|
116 |
+
|
117 |
+
|
118 |
+
def manual_download(url, dst):
|
119 |
+
if "drive.google" in url:
|
120 |
+
logger.info("Drive url")
|
121 |
+
if "folders" in url:
|
122 |
+
logger.info("folder")
|
123 |
+
os.system(f'gdown --folder "{url}" -O {dst} --fuzzy -c')
|
124 |
+
else:
|
125 |
+
logger.info("single")
|
126 |
+
os.system(f'gdown "{url}" -O {dst} --fuzzy -c')
|
127 |
+
elif "huggingface" in url:
|
128 |
+
logger.info("HuggingFace url")
|
129 |
+
if "/blob/" in url or "/resolve/" in url:
|
130 |
+
if "/blob/" in url:
|
131 |
+
url = url.replace("/blob/", "/resolve/")
|
132 |
+
download_manager(url=url, path=dst, overwrite=True, progress=True)
|
133 |
+
else:
|
134 |
+
os.system(f"git clone {url} {dst+'repo/'}")
|
135 |
+
elif "http" in url:
|
136 |
+
logger.info("URL")
|
137 |
+
download_manager(url=url, path=dst, overwrite=True, progress=True)
|
138 |
+
elif os.path.exists(url):
|
139 |
+
logger.info("Path")
|
140 |
+
copy_files(url, dst)
|
141 |
+
else:
|
142 |
+
logger.error(f"No valid URL: {url}")
|
143 |
+
|
144 |
+
|
145 |
+
def download_list(text_downloads):
|
146 |
+
|
147 |
+
if os.environ.get("ZERO_GPU") == "TRUE":
|
148 |
+
raise RuntimeError("This option is disabled in this demo.")
|
149 |
+
|
150 |
+
try:
|
151 |
+
urls = [elem.strip() for elem in text_downloads.split(",")]
|
152 |
+
except Exception as error:
|
153 |
+
raise ValueError(f"No valid URL. {str(error)}")
|
154 |
+
|
155 |
+
create_directories(["downloads", "logs", "weights"])
|
156 |
+
|
157 |
+
path_download = "downloads/"
|
158 |
+
for url in urls:
|
159 |
+
manual_download(url, path_download)
|
160 |
+
|
161 |
+
# Tree
|
162 |
+
print("####################################")
|
163 |
+
print_tree_directory("downloads", indent="")
|
164 |
+
print("####################################")
|
165 |
+
|
166 |
+
# Place files
|
167 |
+
select_zip_and_rar_files("downloads/")
|
168 |
+
|
169 |
+
models, _ = upload_model_list()
|
170 |
+
|
171 |
+
# hf space models files delete
|
172 |
+
remove_directory_contents("downloads/repo")
|
173 |
+
|
174 |
+
return f"Downloaded = {models}"
|
175 |
+
|
176 |
+
|
177 |
+
def select_zip_and_rar_files(directory_path="downloads/"):
|
178 |
+
# filter
|
179 |
+
zip_files = []
|
180 |
+
rar_files = []
|
181 |
+
|
182 |
+
for file_name in os.listdir(directory_path):
|
183 |
+
if file_name.endswith(".zip"):
|
184 |
+
zip_files.append(file_name)
|
185 |
+
elif file_name.endswith(".rar"):
|
186 |
+
rar_files.append(file_name)
|
187 |
+
|
188 |
+
# extract
|
189 |
+
for file_name in zip_files:
|
190 |
+
file_path = os.path.join(directory_path, file_name)
|
191 |
+
with zipfile.ZipFile(file_path, "r") as zip_ref:
|
192 |
+
zip_ref.extractall(directory_path)
|
193 |
+
|
194 |
+
for file_name in rar_files:
|
195 |
+
file_path = os.path.join(directory_path, file_name)
|
196 |
+
with rarfile.RarFile(file_path, "r") as rar_ref:
|
197 |
+
rar_ref.extractall(directory_path)
|
198 |
+
|
199 |
+
# set in path
|
200 |
+
def move_files_with_extension(src_dir, extension, destination_dir):
|
201 |
+
for root, _, files in os.walk(src_dir):
|
202 |
+
for file_name in files:
|
203 |
+
if file_name.endswith(extension):
|
204 |
+
source_file = os.path.join(root, file_name)
|
205 |
+
destination = os.path.join(destination_dir, file_name)
|
206 |
+
shutil.move(source_file, destination)
|
207 |
+
|
208 |
+
move_files_with_extension(directory_path, ".index", "logs/")
|
209 |
+
move_files_with_extension(directory_path, ".pth", "weights/")
|
210 |
+
|
211 |
+
return "Download complete"
|
212 |
+
|
213 |
+
|
214 |
+
def is_file_with_extensions(string_path, extensions):
|
215 |
+
return any(string_path.lower().endswith(ext) for ext in extensions)
|
216 |
+
|
217 |
+
|
218 |
+
def is_video_file(string_path):
|
219 |
+
return is_file_with_extensions(string_path, VIDEO_EXTENSIONS)
|
220 |
+
|
221 |
+
|
222 |
+
def is_audio_file(string_path):
|
223 |
+
return is_file_with_extensions(string_path, AUDIO_EXTENSIONS)
|
224 |
+
|
225 |
+
|
226 |
+
def is_subtitle_file(string_path):
|
227 |
+
return is_file_with_extensions(string_path, SUBTITLE_EXTENSIONS)
|
228 |
+
|
229 |
+
|
230 |
+
def get_directory_files(directory):
|
231 |
+
audio_files = []
|
232 |
+
video_files = []
|
233 |
+
sub_files = []
|
234 |
+
|
235 |
+
for item in os.listdir(directory):
|
236 |
+
item_path = os.path.join(directory, item)
|
237 |
+
|
238 |
+
if os.path.isfile(item_path):
|
239 |
+
|
240 |
+
if is_audio_file(item_path):
|
241 |
+
audio_files.append(item_path)
|
242 |
+
|
243 |
+
elif is_video_file(item_path):
|
244 |
+
video_files.append(item_path)
|
245 |
+
|
246 |
+
elif is_subtitle_file(item_path):
|
247 |
+
sub_files.append(item_path)
|
248 |
+
|
249 |
+
logger.info(
|
250 |
+
f"Files in path ({directory}): "
|
251 |
+
f"{str(audio_files + video_files + sub_files)}"
|
252 |
+
)
|
253 |
+
|
254 |
+
return audio_files, video_files, sub_files
|
255 |
+
|
256 |
+
|
257 |
+
def get_valid_files(paths):
|
258 |
+
valid_paths = []
|
259 |
+
for path in paths:
|
260 |
+
if os.path.isdir(path):
|
261 |
+
audio_files, video_files, sub_files = get_directory_files(path)
|
262 |
+
valid_paths.extend(audio_files)
|
263 |
+
valid_paths.extend(video_files)
|
264 |
+
valid_paths.extend(sub_files)
|
265 |
+
else:
|
266 |
+
valid_paths.append(path)
|
267 |
+
|
268 |
+
return valid_paths
|
269 |
+
|
270 |
+
|
271 |
+
def extract_video_links(link):
|
272 |
+
|
273 |
+
params_dlp = {"quiet": False, "no_warnings": True, "noplaylist": False}
|
274 |
+
|
275 |
+
try:
|
276 |
+
from yt_dlp import YoutubeDL
|
277 |
+
with capture.capture_output() as cap:
|
278 |
+
with YoutubeDL(params_dlp) as ydl:
|
279 |
+
info_dict = ydl.extract_info( # noqa
|
280 |
+
link, download=False, process=True
|
281 |
+
)
|
282 |
+
|
283 |
+
urls = re.findall(r'\[youtube\] Extracting URL: (.*?)\n', cap.stdout)
|
284 |
+
logger.info(f"List of videos in ({link}): {str(urls)}")
|
285 |
+
del cap
|
286 |
+
except Exception as error:
|
287 |
+
logger.error(f"{link} >> {str(error)}")
|
288 |
+
urls = [link]
|
289 |
+
|
290 |
+
return urls
|
291 |
+
|
292 |
+
|
293 |
+
def get_link_list(urls):
|
294 |
+
valid_links = []
|
295 |
+
for url_video in urls:
|
296 |
+
if "youtube.com" in url_video and "/watch?v=" not in url_video:
|
297 |
+
url_links = extract_video_links(url_video)
|
298 |
+
valid_links.extend(url_links)
|
299 |
+
else:
|
300 |
+
valid_links.append(url_video)
|
301 |
+
return valid_links
|
302 |
+
|
303 |
+
# =====================================
|
304 |
+
# Download Manager
|
305 |
+
# =====================================
|
306 |
+
|
307 |
+
|
308 |
+
def load_file_from_url(
|
309 |
+
url: str,
|
310 |
+
model_dir: str,
|
311 |
+
file_name: str | None = None,
|
312 |
+
overwrite: bool = False,
|
313 |
+
progress: bool = True,
|
314 |
+
) -> str:
|
315 |
+
"""Download a file from `url` into `model_dir`,
|
316 |
+
using the file present if possible.
|
317 |
+
|
318 |
+
Returns the path to the downloaded file.
|
319 |
+
"""
|
320 |
+
os.makedirs(model_dir, exist_ok=True)
|
321 |
+
if not file_name:
|
322 |
+
parts = urlparse(url)
|
323 |
+
file_name = os.path.basename(parts.path)
|
324 |
+
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
|
325 |
+
|
326 |
+
# Overwrite
|
327 |
+
if os.path.exists(cached_file):
|
328 |
+
if overwrite or os.path.getsize(cached_file) == 0:
|
329 |
+
remove_files(cached_file)
|
330 |
+
|
331 |
+
# Download
|
332 |
+
if not os.path.exists(cached_file):
|
333 |
+
logger.info(f'Downloading: "{url}" to {cached_file}\n')
|
334 |
+
from torch.hub import download_url_to_file
|
335 |
+
|
336 |
+
download_url_to_file(url, cached_file, progress=progress)
|
337 |
+
else:
|
338 |
+
logger.debug(cached_file)
|
339 |
+
|
340 |
+
return cached_file
|
341 |
+
|
342 |
+
|
343 |
+
def friendly_name(file: str):
|
344 |
+
if file.startswith("http"):
|
345 |
+
file = urlparse(file).path
|
346 |
+
|
347 |
+
file = os.path.basename(file)
|
348 |
+
model_name, extension = os.path.splitext(file)
|
349 |
+
return model_name, extension
|
350 |
+
|
351 |
+
|
352 |
+
def download_manager(
|
353 |
+
url: str,
|
354 |
+
path: str,
|
355 |
+
extension: str = "",
|
356 |
+
overwrite: bool = False,
|
357 |
+
progress: bool = True,
|
358 |
+
):
|
359 |
+
url = url.strip()
|
360 |
+
|
361 |
+
name, ext = friendly_name(url)
|
362 |
+
name += ext if not extension else f".{extension}"
|
363 |
+
|
364 |
+
if url.startswith("http"):
|
365 |
+
filename = load_file_from_url(
|
366 |
+
url=url,
|
367 |
+
model_dir=path,
|
368 |
+
file_name=name,
|
369 |
+
overwrite=overwrite,
|
370 |
+
progress=progress,
|
371 |
+
)
|
372 |
+
else:
|
373 |
+
filename = path
|
374 |
+
|
375 |
+
return filename
|
376 |
+
|
377 |
+
|
378 |
+
# =====================================
|
379 |
+
# File management
|
380 |
+
# =====================================
|
381 |
+
|
382 |
+
|
383 |
+
# only remove files
|
384 |
+
def remove_files(file_list):
|
385 |
+
if isinstance(file_list, str):
|
386 |
+
file_list = [file_list]
|
387 |
+
|
388 |
+
for file in file_list:
|
389 |
+
if os.path.exists(file):
|
390 |
+
os.remove(file)
|
391 |
+
|
392 |
+
|
393 |
+
def remove_directory_contents(directory_path):
|
394 |
+
"""
|
395 |
+
Removes all files and subdirectories within a directory.
|
396 |
+
|
397 |
+
Parameters:
|
398 |
+
directory_path (str): Path to the directory whose
|
399 |
+
contents need to be removed.
|
400 |
+
"""
|
401 |
+
if os.path.exists(directory_path):
|
402 |
+
for filename in os.listdir(directory_path):
|
403 |
+
file_path = os.path.join(directory_path, filename)
|
404 |
+
try:
|
405 |
+
if os.path.isfile(file_path):
|
406 |
+
os.remove(file_path)
|
407 |
+
elif os.path.isdir(file_path):
|
408 |
+
shutil.rmtree(file_path)
|
409 |
+
except Exception as e:
|
410 |
+
logger.error(f"Failed to delete {file_path}. Reason: {e}")
|
411 |
+
logger.info(f"Content in '{directory_path}' removed.")
|
412 |
+
else:
|
413 |
+
logger.error(f"Directory '{directory_path}' does not exist.")
|
414 |
+
|
415 |
+
|
416 |
+
# Create directory if not exists
|
417 |
+
def create_directories(directory_path):
|
418 |
+
if isinstance(directory_path, str):
|
419 |
+
directory_path = [directory_path]
|
420 |
+
for one_dir_path in directory_path:
|
421 |
+
if not os.path.exists(one_dir_path):
|
422 |
+
os.makedirs(one_dir_path)
|
423 |
+
logger.debug(f"Directory '{one_dir_path}' created.")
|
424 |
+
|
425 |
+
|
426 |
+
def move_files(source_dir, destination_dir, extension=""):
|
427 |
+
"""
|
428 |
+
Moves file(s) from the source path to the destination path.
|
429 |
+
|
430 |
+
Parameters:
|
431 |
+
source_dir (str): Path to the source directory.
|
432 |
+
destination_dir (str): Path to the destination directory.
|
433 |
+
extension (str): Only move files with this extension.
|
434 |
+
"""
|
435 |
+
create_directories(destination_dir)
|
436 |
+
|
437 |
+
for filename in os.listdir(source_dir):
|
438 |
+
source_path = os.path.join(source_dir, filename)
|
439 |
+
destination_path = os.path.join(destination_dir, filename)
|
440 |
+
if extension and not filename.endswith(extension):
|
441 |
+
continue
|
442 |
+
os.replace(source_path, destination_path)
|
443 |
+
|
444 |
+
|
445 |
+
def copy_files(source_path, destination_path):
|
446 |
+
"""
|
447 |
+
Copies a file or multiple files from a source path to a destination path.
|
448 |
+
|
449 |
+
Parameters:
|
450 |
+
source_path (str or list): Path or list of paths to the source
|
451 |
+
file(s) or directory.
|
452 |
+
destination_path (str): Path to the destination directory.
|
453 |
+
"""
|
454 |
+
create_directories(destination_path)
|
455 |
+
|
456 |
+
if isinstance(source_path, str):
|
457 |
+
source_path = [source_path]
|
458 |
+
|
459 |
+
if os.path.isdir(source_path[0]):
|
460 |
+
# Copy all files from the source directory to the destination directory
|
461 |
+
base_path = source_path[0]
|
462 |
+
source_path = os.listdir(source_path[0])
|
463 |
+
source_path = [
|
464 |
+
os.path.join(base_path, file_name) for file_name in source_path
|
465 |
+
]
|
466 |
+
|
467 |
+
for one_source_path in source_path:
|
468 |
+
if os.path.exists(one_source_path):
|
469 |
+
shutil.copy2(one_source_path, destination_path)
|
470 |
+
logger.debug(
|
471 |
+
f"File '{one_source_path}' copied to '{destination_path}'."
|
472 |
+
)
|
473 |
+
else:
|
474 |
+
logger.error(f"File '{one_source_path}' does not exist.")
|
475 |
+
|
476 |
+
|
477 |
+
def rename_file(current_name, new_name):
|
478 |
+
file_directory = os.path.dirname(current_name)
|
479 |
+
|
480 |
+
if os.path.exists(current_name):
|
481 |
+
dir_new_name_file = os.path.join(file_directory, new_name)
|
482 |
+
os.rename(current_name, dir_new_name_file)
|
483 |
+
logger.debug(f"File '{current_name}' renamed to '{new_name}'.")
|
484 |
+
return dir_new_name_file
|
485 |
+
else:
|
486 |
+
logger.error(f"File '{current_name}' does not exist.")
|
487 |
+
return None
|
vci_pipeline.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np, parselmouth, torch, pdb, sys
|
2 |
+
from time import time as ttime
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import scipy.signal as signal
|
5 |
+
import pyworld, os, traceback, faiss, librosa, torchcrepe
|
6 |
+
from scipy import signal
|
7 |
+
from functools import lru_cache
|
8 |
+
from soni_translate.logging_setup import logger
|
9 |
+
|
10 |
+
now_dir = os.getcwd()
|
11 |
+
sys.path.append(now_dir)
|
12 |
+
|
13 |
+
bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
|
14 |
+
|
15 |
+
input_audio_path2wav = {}
|
16 |
+
|
17 |
+
|
18 |
+
@lru_cache
|
19 |
+
def cache_harvest_f0(input_audio_path, fs, f0max, f0min, frame_period):
|
20 |
+
audio = input_audio_path2wav[input_audio_path]
|
21 |
+
f0, t = pyworld.harvest(
|
22 |
+
audio,
|
23 |
+
fs=fs,
|
24 |
+
f0_ceil=f0max,
|
25 |
+
f0_floor=f0min,
|
26 |
+
frame_period=frame_period,
|
27 |
+
)
|
28 |
+
f0 = pyworld.stonemask(audio, f0, t, fs)
|
29 |
+
return f0
|
30 |
+
|
31 |
+
|
32 |
+
def change_rms(data1, sr1, data2, sr2, rate): # 1 is the input audio, 2 is the output audio, rate is the proportion of 2
|
33 |
+
# print(data1.max(),data2.max())
|
34 |
+
rms1 = librosa.feature.rms(
|
35 |
+
y=data1, frame_length=sr1 // 2 * 2, hop_length=sr1 // 2
|
36 |
+
) # one dot every half second
|
37 |
+
rms2 = librosa.feature.rms(y=data2, frame_length=sr2 // 2 * 2, hop_length=sr2 // 2)
|
38 |
+
rms1 = torch.from_numpy(rms1)
|
39 |
+
rms1 = F.interpolate(
|
40 |
+
rms1.unsqueeze(0), size=data2.shape[0], mode="linear"
|
41 |
+
).squeeze()
|
42 |
+
rms2 = torch.from_numpy(rms2)
|
43 |
+
rms2 = F.interpolate(
|
44 |
+
rms2.unsqueeze(0), size=data2.shape[0], mode="linear"
|
45 |
+
).squeeze()
|
46 |
+
rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-6)
|
47 |
+
data2 *= (
|
48 |
+
torch.pow(rms1, torch.tensor(1 - rate))
|
49 |
+
* torch.pow(rms2, torch.tensor(rate - 1))
|
50 |
+
).numpy()
|
51 |
+
return data2
|
52 |
+
|
53 |
+
|
54 |
+
class VC(object):
|
55 |
+
def __init__(self, tgt_sr, config):
|
56 |
+
self.x_pad, self.x_query, self.x_center, self.x_max, self.is_half = (
|
57 |
+
config.x_pad,
|
58 |
+
config.x_query,
|
59 |
+
config.x_center,
|
60 |
+
config.x_max,
|
61 |
+
config.is_half,
|
62 |
+
)
|
63 |
+
self.sr = 16000 # hubert input sampling rate
|
64 |
+
self.window = 160 # points per frame
|
65 |
+
self.t_pad = self.sr * self.x_pad # Pad time before and after each bar
|
66 |
+
self.t_pad_tgt = tgt_sr * self.x_pad
|
67 |
+
self.t_pad2 = self.t_pad * 2
|
68 |
+
self.t_query = self.sr * self.x_query # Query time before and after the cut point
|
69 |
+
self.t_center = self.sr * self.x_center # Query point cut position
|
70 |
+
self.t_max = self.sr * self.x_max # Query-free duration threshold
|
71 |
+
self.device = config.device
|
72 |
+
|
73 |
+
def get_f0(
|
74 |
+
self,
|
75 |
+
input_audio_path,
|
76 |
+
x,
|
77 |
+
p_len,
|
78 |
+
f0_up_key,
|
79 |
+
f0_method,
|
80 |
+
filter_radius,
|
81 |
+
inp_f0=None,
|
82 |
+
):
|
83 |
+
global input_audio_path2wav
|
84 |
+
time_step = self.window / self.sr * 1000
|
85 |
+
f0_min = 50
|
86 |
+
f0_max = 1100
|
87 |
+
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
|
88 |
+
f0_mel_max = 1127 * np.log(1 + f0_max / 700)
|
89 |
+
if f0_method == "pm":
|
90 |
+
f0 = (
|
91 |
+
parselmouth.Sound(x, self.sr)
|
92 |
+
.to_pitch_ac(
|
93 |
+
time_step=time_step / 1000,
|
94 |
+
voicing_threshold=0.6,
|
95 |
+
pitch_floor=f0_min,
|
96 |
+
pitch_ceiling=f0_max,
|
97 |
+
)
|
98 |
+
.selected_array["frequency"]
|
99 |
+
)
|
100 |
+
pad_size = (p_len - len(f0) + 1) // 2
|
101 |
+
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
|
102 |
+
f0 = np.pad(
|
103 |
+
f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant"
|
104 |
+
)
|
105 |
+
elif f0_method == "harvest":
|
106 |
+
input_audio_path2wav[input_audio_path] = x.astype(np.double)
|
107 |
+
f0 = cache_harvest_f0(input_audio_path, self.sr, f0_max, f0_min, 10)
|
108 |
+
if filter_radius > 2:
|
109 |
+
f0 = signal.medfilt(f0, 3)
|
110 |
+
elif f0_method == "crepe":
|
111 |
+
model = "full"
|
112 |
+
# Pick a batch size that doesn't cause memory errors on your gpu
|
113 |
+
batch_size = 512
|
114 |
+
# Compute pitch using first gpu
|
115 |
+
audio = torch.tensor(np.copy(x))[None].float()
|
116 |
+
f0, pd = torchcrepe.predict(
|
117 |
+
audio,
|
118 |
+
self.sr,
|
119 |
+
self.window,
|
120 |
+
f0_min,
|
121 |
+
f0_max,
|
122 |
+
model,
|
123 |
+
batch_size=batch_size,
|
124 |
+
device=self.device,
|
125 |
+
return_periodicity=True,
|
126 |
+
)
|
127 |
+
pd = torchcrepe.filter.median(pd, 3)
|
128 |
+
f0 = torchcrepe.filter.mean(f0, 3)
|
129 |
+
f0[pd < 0.1] = 0
|
130 |
+
f0 = f0[0].cpu().numpy()
|
131 |
+
elif "rmvpe" in f0_method:
|
132 |
+
if hasattr(self, "model_rmvpe") == False:
|
133 |
+
from lib.rmvpe import RMVPE
|
134 |
+
|
135 |
+
logger.info("Loading vocal pitch estimator model")
|
136 |
+
self.model_rmvpe = RMVPE(
|
137 |
+
"rmvpe.pt", is_half=self.is_half, device=self.device
|
138 |
+
)
|
139 |
+
thred = 0.03
|
140 |
+
if "+" in f0_method:
|
141 |
+
f0 = self.model_rmvpe.pitch_based_audio_inference(x, thred, f0_min, f0_max)
|
142 |
+
else:
|
143 |
+
f0 = self.model_rmvpe.infer_from_audio(x, thred)
|
144 |
+
|
145 |
+
f0 *= pow(2, f0_up_key / 12)
|
146 |
+
# with open("test.txt","w")as f:f.write("\n".join([str(i)for i in f0.tolist()]))
|
147 |
+
tf0 = self.sr // self.window # f0 points per second
|
148 |
+
if inp_f0 is not None:
|
149 |
+
delta_t = np.round(
|
150 |
+
(inp_f0[:, 0].max() - inp_f0[:, 0].min()) * tf0 + 1
|
151 |
+
).astype("int16")
|
152 |
+
replace_f0 = np.interp(
|
153 |
+
list(range(delta_t)), inp_f0[:, 0] * 100, inp_f0[:, 1]
|
154 |
+
)
|
155 |
+
shape = f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)].shape[0]
|
156 |
+
f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)] = replace_f0[
|
157 |
+
:shape
|
158 |
+
]
|
159 |
+
# with open("test_opt.txt","w")as f:f.write("\n".join([str(i)for i in f0.tolist()]))
|
160 |
+
f0bak = f0.copy()
|
161 |
+
f0_mel = 1127 * np.log(1 + f0 / 700)
|
162 |
+
f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (
|
163 |
+
f0_mel_max - f0_mel_min
|
164 |
+
) + 1
|
165 |
+
f0_mel[f0_mel <= 1] = 1
|
166 |
+
f0_mel[f0_mel > 255] = 255
|
167 |
+
try:
|
168 |
+
f0_coarse = np.rint(f0_mel).astype(np.int)
|
169 |
+
except: # noqa
|
170 |
+
f0_coarse = np.rint(f0_mel).astype(int)
|
171 |
+
return f0_coarse, f0bak # 1-0
|
172 |
+
|
173 |
+
def vc(
|
174 |
+
self,
|
175 |
+
model,
|
176 |
+
net_g,
|
177 |
+
sid,
|
178 |
+
audio0,
|
179 |
+
pitch,
|
180 |
+
pitchf,
|
181 |
+
times,
|
182 |
+
index,
|
183 |
+
big_npy,
|
184 |
+
index_rate,
|
185 |
+
version,
|
186 |
+
protect,
|
187 |
+
): # ,file_index,file_big_npy
|
188 |
+
feats = torch.from_numpy(audio0)
|
189 |
+
if self.is_half:
|
190 |
+
feats = feats.half()
|
191 |
+
else:
|
192 |
+
feats = feats.float()
|
193 |
+
if feats.dim() == 2: # double channels
|
194 |
+
feats = feats.mean(-1)
|
195 |
+
assert feats.dim() == 1, feats.dim()
|
196 |
+
feats = feats.view(1, -1)
|
197 |
+
padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
|
198 |
+
|
199 |
+
inputs = {
|
200 |
+
"source": feats.to(self.device),
|
201 |
+
"padding_mask": padding_mask,
|
202 |
+
"output_layer": 9 if version == "v1" else 12,
|
203 |
+
}
|
204 |
+
t0 = ttime()
|
205 |
+
with torch.no_grad():
|
206 |
+
logits = model.extract_features(**inputs)
|
207 |
+
feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
|
208 |
+
if protect < 0.5 and pitch != None and pitchf != None:
|
209 |
+
feats0 = feats.clone()
|
210 |
+
if (
|
211 |
+
isinstance(index, type(None)) == False
|
212 |
+
and isinstance(big_npy, type(None)) == False
|
213 |
+
and index_rate != 0
|
214 |
+
):
|
215 |
+
npy = feats[0].cpu().numpy()
|
216 |
+
if self.is_half:
|
217 |
+
npy = npy.astype("float32")
|
218 |
+
|
219 |
+
# _, I = index.search(npy, 1)
|
220 |
+
# npy = big_npy[I.squeeze()]
|
221 |
+
|
222 |
+
score, ix = index.search(npy, k=8)
|
223 |
+
weight = np.square(1 / score)
|
224 |
+
weight /= weight.sum(axis=1, keepdims=True)
|
225 |
+
npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
|
226 |
+
|
227 |
+
if self.is_half:
|
228 |
+
npy = npy.astype("float16")
|
229 |
+
feats = (
|
230 |
+
torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate
|
231 |
+
+ (1 - index_rate) * feats
|
232 |
+
)
|
233 |
+
|
234 |
+
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
|
235 |
+
if protect < 0.5 and pitch != None and pitchf != None:
|
236 |
+
feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
|
237 |
+
0, 2, 1
|
238 |
+
)
|
239 |
+
t1 = ttime()
|
240 |
+
p_len = audio0.shape[0] // self.window
|
241 |
+
if feats.shape[1] < p_len:
|
242 |
+
p_len = feats.shape[1]
|
243 |
+
if pitch != None and pitchf != None:
|
244 |
+
pitch = pitch[:, :p_len]
|
245 |
+
pitchf = pitchf[:, :p_len]
|
246 |
+
|
247 |
+
if protect < 0.5 and pitch != None and pitchf != None:
|
248 |
+
pitchff = pitchf.clone()
|
249 |
+
pitchff[pitchf > 0] = 1
|
250 |
+
pitchff[pitchf < 1] = protect
|
251 |
+
pitchff = pitchff.unsqueeze(-1)
|
252 |
+
feats = feats * pitchff + feats0 * (1 - pitchff)
|
253 |
+
feats = feats.to(feats0.dtype)
|
254 |
+
p_len = torch.tensor([p_len], device=self.device).long()
|
255 |
+
with torch.no_grad():
|
256 |
+
if pitch != None and pitchf != None:
|
257 |
+
audio1 = (
|
258 |
+
(net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0])
|
259 |
+
.data.cpu()
|
260 |
+
.float()
|
261 |
+
.numpy()
|
262 |
+
)
|
263 |
+
else:
|
264 |
+
audio1 = (
|
265 |
+
(net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy()
|
266 |
+
)
|
267 |
+
del feats, p_len, padding_mask
|
268 |
+
if torch.cuda.is_available():
|
269 |
+
torch.cuda.empty_cache()
|
270 |
+
t2 = ttime()
|
271 |
+
times[0] += t1 - t0
|
272 |
+
times[2] += t2 - t1
|
273 |
+
return audio1
|
274 |
+
|
275 |
+
def pipeline(
|
276 |
+
self,
|
277 |
+
model,
|
278 |
+
net_g,
|
279 |
+
sid,
|
280 |
+
audio,
|
281 |
+
input_audio_path,
|
282 |
+
times,
|
283 |
+
f0_up_key,
|
284 |
+
f0_method,
|
285 |
+
file_index,
|
286 |
+
# file_big_npy,
|
287 |
+
index_rate,
|
288 |
+
if_f0,
|
289 |
+
filter_radius,
|
290 |
+
tgt_sr,
|
291 |
+
resample_sr,
|
292 |
+
rms_mix_rate,
|
293 |
+
version,
|
294 |
+
protect,
|
295 |
+
f0_file=None,
|
296 |
+
):
|
297 |
+
if (
|
298 |
+
file_index != ""
|
299 |
+
# and file_big_npy != ""
|
300 |
+
# and os.path.exists(file_big_npy) == True
|
301 |
+
and os.path.exists(file_index) == True
|
302 |
+
and index_rate != 0
|
303 |
+
):
|
304 |
+
try:
|
305 |
+
index = faiss.read_index(file_index)
|
306 |
+
# big_npy = np.load(file_big_npy)
|
307 |
+
big_npy = index.reconstruct_n(0, index.ntotal)
|
308 |
+
except:
|
309 |
+
traceback.print_exc()
|
310 |
+
index = big_npy = None
|
311 |
+
else:
|
312 |
+
index = big_npy = None
|
313 |
+
logger.warning("File index Not found, set None")
|
314 |
+
|
315 |
+
audio = signal.filtfilt(bh, ah, audio)
|
316 |
+
audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
|
317 |
+
opt_ts = []
|
318 |
+
if audio_pad.shape[0] > self.t_max:
|
319 |
+
audio_sum = np.zeros_like(audio)
|
320 |
+
for i in range(self.window):
|
321 |
+
audio_sum += audio_pad[i : i - self.window]
|
322 |
+
for t in range(self.t_center, audio.shape[0], self.t_center):
|
323 |
+
opt_ts.append(
|
324 |
+
t
|
325 |
+
- self.t_query
|
326 |
+
+ np.where(
|
327 |
+
np.abs(audio_sum[t - self.t_query : t + self.t_query])
|
328 |
+
== np.abs(audio_sum[t - self.t_query : t + self.t_query]).min()
|
329 |
+
)[0][0]
|
330 |
+
)
|
331 |
+
s = 0
|
332 |
+
audio_opt = []
|
333 |
+
t = None
|
334 |
+
t1 = ttime()
|
335 |
+
audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
|
336 |
+
p_len = audio_pad.shape[0] // self.window
|
337 |
+
inp_f0 = None
|
338 |
+
if hasattr(f0_file, "name") == True:
|
339 |
+
try:
|
340 |
+
with open(f0_file.name, "r") as f:
|
341 |
+
lines = f.read().strip("\n").split("\n")
|
342 |
+
inp_f0 = []
|
343 |
+
for line in lines:
|
344 |
+
inp_f0.append([float(i) for i in line.split(",")])
|
345 |
+
inp_f0 = np.array(inp_f0, dtype="float32")
|
346 |
+
except:
|
347 |
+
traceback.print_exc()
|
348 |
+
sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
|
349 |
+
pitch, pitchf = None, None
|
350 |
+
if if_f0 == 1:
|
351 |
+
pitch, pitchf = self.get_f0(
|
352 |
+
input_audio_path,
|
353 |
+
audio_pad,
|
354 |
+
p_len,
|
355 |
+
f0_up_key,
|
356 |
+
f0_method,
|
357 |
+
filter_radius,
|
358 |
+
inp_f0,
|
359 |
+
)
|
360 |
+
pitch = pitch[:p_len]
|
361 |
+
pitchf = pitchf[:p_len]
|
362 |
+
if self.device == "mps":
|
363 |
+
pitchf = pitchf.astype(np.float32)
|
364 |
+
pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
|
365 |
+
pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
|
366 |
+
t2 = ttime()
|
367 |
+
times[1] += t2 - t1
|
368 |
+
for t in opt_ts:
|
369 |
+
t = t // self.window * self.window
|
370 |
+
if if_f0 == 1:
|
371 |
+
audio_opt.append(
|
372 |
+
self.vc(
|
373 |
+
model,
|
374 |
+
net_g,
|
375 |
+
sid,
|
376 |
+
audio_pad[s : t + self.t_pad2 + self.window],
|
377 |
+
pitch[:, s // self.window : (t + self.t_pad2) // self.window],
|
378 |
+
pitchf[:, s // self.window : (t + self.t_pad2) // self.window],
|
379 |
+
times,
|
380 |
+
index,
|
381 |
+
big_npy,
|
382 |
+
index_rate,
|
383 |
+
version,
|
384 |
+
protect,
|
385 |
+
)[self.t_pad_tgt : -self.t_pad_tgt]
|
386 |
+
)
|
387 |
+
else:
|
388 |
+
audio_opt.append(
|
389 |
+
self.vc(
|
390 |
+
model,
|
391 |
+
net_g,
|
392 |
+
sid,
|
393 |
+
audio_pad[s : t + self.t_pad2 + self.window],
|
394 |
+
None,
|
395 |
+
None,
|
396 |
+
times,
|
397 |
+
index,
|
398 |
+
big_npy,
|
399 |
+
index_rate,
|
400 |
+
version,
|
401 |
+
protect,
|
402 |
+
)[self.t_pad_tgt : -self.t_pad_tgt]
|
403 |
+
)
|
404 |
+
s = t
|
405 |
+
if if_f0 == 1:
|
406 |
+
audio_opt.append(
|
407 |
+
self.vc(
|
408 |
+
model,
|
409 |
+
net_g,
|
410 |
+
sid,
|
411 |
+
audio_pad[t:],
|
412 |
+
pitch[:, t // self.window :] if t is not None else pitch,
|
413 |
+
pitchf[:, t // self.window :] if t is not None else pitchf,
|
414 |
+
times,
|
415 |
+
index,
|
416 |
+
big_npy,
|
417 |
+
index_rate,
|
418 |
+
version,
|
419 |
+
protect,
|
420 |
+
)[self.t_pad_tgt : -self.t_pad_tgt]
|
421 |
+
)
|
422 |
+
else:
|
423 |
+
audio_opt.append(
|
424 |
+
self.vc(
|
425 |
+
model,
|
426 |
+
net_g,
|
427 |
+
sid,
|
428 |
+
audio_pad[t:],
|
429 |
+
None,
|
430 |
+
None,
|
431 |
+
times,
|
432 |
+
index,
|
433 |
+
big_npy,
|
434 |
+
index_rate,
|
435 |
+
version,
|
436 |
+
protect,
|
437 |
+
)[self.t_pad_tgt : -self.t_pad_tgt]
|
438 |
+
)
|
439 |
+
audio_opt = np.concatenate(audio_opt)
|
440 |
+
if rms_mix_rate != 1:
|
441 |
+
audio_opt = change_rms(audio, 16000, audio_opt, tgt_sr, rms_mix_rate)
|
442 |
+
if resample_sr >= 16000 and tgt_sr != resample_sr:
|
443 |
+
audio_opt = librosa.resample(
|
444 |
+
audio_opt, orig_sr=tgt_sr, target_sr=resample_sr
|
445 |
+
)
|
446 |
+
audio_max = np.abs(audio_opt).max() / 0.99
|
447 |
+
max_int16 = 32768
|
448 |
+
if audio_max > 1:
|
449 |
+
max_int16 /= audio_max
|
450 |
+
audio_opt = (audio_opt * max_int16).astype(np.int16)
|
451 |
+
del pitch, pitchf, sid
|
452 |
+
if torch.cuda.is_available():
|
453 |
+
torch.cuda.empty_cache()
|
454 |
+
return audio_opt
|
voice_main.py
ADDED
@@ -0,0 +1,732 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from soni_translate.logging_setup import logger
|
2 |
+
import torch
|
3 |
+
import gc
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import shutil
|
7 |
+
import warnings
|
8 |
+
import threading
|
9 |
+
from tqdm import tqdm
|
10 |
+
from lib.infer_pack.models import (
|
11 |
+
SynthesizerTrnMs256NSFsid,
|
12 |
+
SynthesizerTrnMs256NSFsid_nono,
|
13 |
+
SynthesizerTrnMs768NSFsid,
|
14 |
+
SynthesizerTrnMs768NSFsid_nono,
|
15 |
+
)
|
16 |
+
from lib.audio import load_audio
|
17 |
+
import soundfile as sf
|
18 |
+
import edge_tts
|
19 |
+
import asyncio
|
20 |
+
from soni_translate.utils import remove_directory_contents, create_directories
|
21 |
+
from scipy import signal
|
22 |
+
from time import time as ttime
|
23 |
+
import faiss
|
24 |
+
from vci_pipeline import VC, change_rms, bh, ah
|
25 |
+
import librosa
|
26 |
+
|
27 |
+
warnings.filterwarnings("ignore")
|
28 |
+
|
29 |
+
|
30 |
+
class Config:
|
31 |
+
def __init__(self, only_cpu=False):
|
32 |
+
self.device = "cuda:0"
|
33 |
+
self.is_half = True
|
34 |
+
self.n_cpu = 0
|
35 |
+
self.gpu_name = None
|
36 |
+
self.gpu_mem = None
|
37 |
+
(
|
38 |
+
self.x_pad,
|
39 |
+
self.x_query,
|
40 |
+
self.x_center,
|
41 |
+
self.x_max
|
42 |
+
) = self.device_config(only_cpu)
|
43 |
+
|
44 |
+
def device_config(self, only_cpu) -> tuple:
|
45 |
+
if torch.cuda.is_available() and not only_cpu:
|
46 |
+
i_device = int(self.device.split(":")[-1])
|
47 |
+
self.gpu_name = torch.cuda.get_device_name(i_device)
|
48 |
+
if (
|
49 |
+
("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
|
50 |
+
or "P40" in self.gpu_name.upper()
|
51 |
+
or "1060" in self.gpu_name
|
52 |
+
or "1070" in self.gpu_name
|
53 |
+
or "1080" in self.gpu_name
|
54 |
+
):
|
55 |
+
logger.info(
|
56 |
+
"16/10 Series GPUs and P40 excel "
|
57 |
+
"in single-precision tasks."
|
58 |
+
)
|
59 |
+
self.is_half = False
|
60 |
+
else:
|
61 |
+
self.gpu_name = None
|
62 |
+
self.gpu_mem = int(
|
63 |
+
torch.cuda.get_device_properties(i_device).total_memory
|
64 |
+
/ 1024
|
65 |
+
/ 1024
|
66 |
+
/ 1024
|
67 |
+
+ 0.4
|
68 |
+
)
|
69 |
+
elif torch.backends.mps.is_available() and not only_cpu:
|
70 |
+
logger.info("Supported N-card not found, using MPS for inference")
|
71 |
+
self.device = "mps"
|
72 |
+
else:
|
73 |
+
logger.info("No supported N-card found, using CPU for inference")
|
74 |
+
self.device = "cpu"
|
75 |
+
self.is_half = False
|
76 |
+
|
77 |
+
if self.n_cpu == 0:
|
78 |
+
self.n_cpu = os.cpu_count()
|
79 |
+
|
80 |
+
if self.is_half:
|
81 |
+
# 6GB VRAM configuration
|
82 |
+
x_pad = 3
|
83 |
+
x_query = 10
|
84 |
+
x_center = 60
|
85 |
+
x_max = 65
|
86 |
+
else:
|
87 |
+
# 5GB VRAM configuration
|
88 |
+
x_pad = 1
|
89 |
+
x_query = 6
|
90 |
+
x_center = 38
|
91 |
+
x_max = 41
|
92 |
+
|
93 |
+
if self.gpu_mem is not None and self.gpu_mem <= 4:
|
94 |
+
x_pad = 1
|
95 |
+
x_query = 5
|
96 |
+
x_center = 30
|
97 |
+
x_max = 32
|
98 |
+
|
99 |
+
logger.info(
|
100 |
+
f"Config: Device is {self.device}, "
|
101 |
+
f"half precision is {self.is_half}"
|
102 |
+
)
|
103 |
+
|
104 |
+
return x_pad, x_query, x_center, x_max
|
105 |
+
|
106 |
+
|
107 |
+
BASE_DOWNLOAD_LINK = "https://huggingface.co/r3gm/sonitranslate_voice_models/resolve/main/"
|
108 |
+
BASE_MODELS = [
|
109 |
+
"hubert_base.pt",
|
110 |
+
"rmvpe.pt"
|
111 |
+
]
|
112 |
+
BASE_DIR = "."
|
113 |
+
|
114 |
+
|
115 |
+
def load_hu_bert(config):
|
116 |
+
from fairseq import checkpoint_utils
|
117 |
+
from soni_translate.utils import download_manager
|
118 |
+
|
119 |
+
for id_model in BASE_MODELS:
|
120 |
+
download_manager(
|
121 |
+
os.path.join(BASE_DOWNLOAD_LINK, id_model), BASE_DIR
|
122 |
+
)
|
123 |
+
|
124 |
+
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
125 |
+
["hubert_base.pt"],
|
126 |
+
suffix="",
|
127 |
+
)
|
128 |
+
hubert_model = models[0]
|
129 |
+
hubert_model = hubert_model.to(config.device)
|
130 |
+
if config.is_half:
|
131 |
+
hubert_model = hubert_model.half()
|
132 |
+
else:
|
133 |
+
hubert_model = hubert_model.float()
|
134 |
+
hubert_model.eval()
|
135 |
+
|
136 |
+
return hubert_model
|
137 |
+
|
138 |
+
|
139 |
+
def load_trained_model(model_path, config):
|
140 |
+
|
141 |
+
if not model_path:
|
142 |
+
raise ValueError("No model found")
|
143 |
+
|
144 |
+
logger.info("Loading %s" % model_path)
|
145 |
+
cpt = torch.load(model_path, map_location="cpu")
|
146 |
+
tgt_sr = cpt["config"][-1]
|
147 |
+
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
148 |
+
if_f0 = cpt.get("f0", 1)
|
149 |
+
if if_f0 == 0:
|
150 |
+
# protect to 0.5 need?
|
151 |
+
pass
|
152 |
+
|
153 |
+
version = cpt.get("version", "v1")
|
154 |
+
if version == "v1":
|
155 |
+
if if_f0 == 1:
|
156 |
+
net_g = SynthesizerTrnMs256NSFsid(
|
157 |
+
*cpt["config"], is_half=config.is_half
|
158 |
+
)
|
159 |
+
else:
|
160 |
+
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
161 |
+
elif version == "v2":
|
162 |
+
if if_f0 == 1:
|
163 |
+
net_g = SynthesizerTrnMs768NSFsid(
|
164 |
+
*cpt["config"], is_half=config.is_half
|
165 |
+
)
|
166 |
+
else:
|
167 |
+
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
168 |
+
del net_g.enc_q
|
169 |
+
|
170 |
+
net_g.load_state_dict(cpt["weight"], strict=False)
|
171 |
+
net_g.eval().to(config.device)
|
172 |
+
|
173 |
+
if config.is_half:
|
174 |
+
net_g = net_g.half()
|
175 |
+
else:
|
176 |
+
net_g = net_g.float()
|
177 |
+
|
178 |
+
vc = VC(tgt_sr, config)
|
179 |
+
n_spk = cpt["config"][-3]
|
180 |
+
|
181 |
+
return n_spk, tgt_sr, net_g, vc, cpt, version
|
182 |
+
|
183 |
+
|
184 |
+
class ClassVoices:
|
185 |
+
def __init__(self, only_cpu=False):
|
186 |
+
self.model_config = {}
|
187 |
+
self.config = None
|
188 |
+
self.only_cpu = only_cpu
|
189 |
+
|
190 |
+
def apply_conf(
|
191 |
+
self,
|
192 |
+
tag="base_model",
|
193 |
+
file_model="",
|
194 |
+
pitch_algo="pm",
|
195 |
+
pitch_lvl=0,
|
196 |
+
file_index="",
|
197 |
+
index_influence=0.66,
|
198 |
+
respiration_median_filtering=3,
|
199 |
+
envelope_ratio=0.25,
|
200 |
+
consonant_breath_protection=0.33,
|
201 |
+
resample_sr=0,
|
202 |
+
file_pitch_algo="",
|
203 |
+
):
|
204 |
+
|
205 |
+
if not file_model:
|
206 |
+
raise ValueError("Model not found")
|
207 |
+
|
208 |
+
if file_index is None:
|
209 |
+
file_index = ""
|
210 |
+
|
211 |
+
if file_pitch_algo is None:
|
212 |
+
file_pitch_algo = ""
|
213 |
+
|
214 |
+
if not self.config:
|
215 |
+
self.config = Config(self.only_cpu)
|
216 |
+
self.hu_bert_model = None
|
217 |
+
self.model_pitch_estimator = None
|
218 |
+
|
219 |
+
self.model_config[tag] = {
|
220 |
+
"file_model": file_model,
|
221 |
+
"pitch_algo": pitch_algo,
|
222 |
+
"pitch_lvl": pitch_lvl, # no decimal
|
223 |
+
"file_index": file_index,
|
224 |
+
"index_influence": index_influence,
|
225 |
+
"respiration_median_filtering": respiration_median_filtering,
|
226 |
+
"envelope_ratio": envelope_ratio,
|
227 |
+
"consonant_breath_protection": consonant_breath_protection,
|
228 |
+
"resample_sr": resample_sr,
|
229 |
+
"file_pitch_algo": file_pitch_algo,
|
230 |
+
}
|
231 |
+
return f"CONFIGURATION APPLIED FOR {tag}: {file_model}"
|
232 |
+
|
233 |
+
def infer(
|
234 |
+
self,
|
235 |
+
task_id,
|
236 |
+
params,
|
237 |
+
# load model
|
238 |
+
n_spk,
|
239 |
+
tgt_sr,
|
240 |
+
net_g,
|
241 |
+
pipe,
|
242 |
+
cpt,
|
243 |
+
version,
|
244 |
+
if_f0,
|
245 |
+
# load index
|
246 |
+
index_rate,
|
247 |
+
index,
|
248 |
+
big_npy,
|
249 |
+
# load f0 file
|
250 |
+
inp_f0,
|
251 |
+
# audio file
|
252 |
+
input_audio_path,
|
253 |
+
overwrite,
|
254 |
+
):
|
255 |
+
|
256 |
+
f0_method = params["pitch_algo"]
|
257 |
+
f0_up_key = params["pitch_lvl"]
|
258 |
+
filter_radius = params["respiration_median_filtering"]
|
259 |
+
resample_sr = params["resample_sr"]
|
260 |
+
rms_mix_rate = params["envelope_ratio"]
|
261 |
+
protect = params["consonant_breath_protection"]
|
262 |
+
|
263 |
+
if not os.path.exists(input_audio_path):
|
264 |
+
raise ValueError(
|
265 |
+
"The audio file was not found or is not "
|
266 |
+
f"a valid file: {input_audio_path}"
|
267 |
+
)
|
268 |
+
|
269 |
+
f0_up_key = int(f0_up_key)
|
270 |
+
|
271 |
+
audio = load_audio(input_audio_path, 16000)
|
272 |
+
|
273 |
+
# Normalize audio
|
274 |
+
audio_max = np.abs(audio).max() / 0.95
|
275 |
+
if audio_max > 1:
|
276 |
+
audio /= audio_max
|
277 |
+
|
278 |
+
times = [0, 0, 0]
|
279 |
+
|
280 |
+
# filters audio signal, pads it, computes sliding window sums,
|
281 |
+
# and extracts optimized time indices
|
282 |
+
audio = signal.filtfilt(bh, ah, audio)
|
283 |
+
audio_pad = np.pad(
|
284 |
+
audio, (pipe.window // 2, pipe.window // 2), mode="reflect"
|
285 |
+
)
|
286 |
+
opt_ts = []
|
287 |
+
if audio_pad.shape[0] > pipe.t_max:
|
288 |
+
audio_sum = np.zeros_like(audio)
|
289 |
+
for i in range(pipe.window):
|
290 |
+
audio_sum += audio_pad[i:i - pipe.window]
|
291 |
+
for t in range(pipe.t_center, audio.shape[0], pipe.t_center):
|
292 |
+
opt_ts.append(
|
293 |
+
t
|
294 |
+
- pipe.t_query
|
295 |
+
+ np.where(
|
296 |
+
np.abs(audio_sum[t - pipe.t_query: t + pipe.t_query])
|
297 |
+
== np.abs(audio_sum[t - pipe.t_query: t + pipe.t_query]).min()
|
298 |
+
)[0][0]
|
299 |
+
)
|
300 |
+
|
301 |
+
s = 0
|
302 |
+
audio_opt = []
|
303 |
+
t = None
|
304 |
+
t1 = ttime()
|
305 |
+
|
306 |
+
sid_value = 0
|
307 |
+
sid = torch.tensor(sid_value, device=pipe.device).unsqueeze(0).long()
|
308 |
+
|
309 |
+
# Pads audio symmetrically, calculates length divided by window size.
|
310 |
+
audio_pad = np.pad(audio, (pipe.t_pad, pipe.t_pad), mode="reflect")
|
311 |
+
p_len = audio_pad.shape[0] // pipe.window
|
312 |
+
|
313 |
+
# Estimates pitch from audio signal
|
314 |
+
pitch, pitchf = None, None
|
315 |
+
if if_f0 == 1:
|
316 |
+
pitch, pitchf = pipe.get_f0(
|
317 |
+
input_audio_path,
|
318 |
+
audio_pad,
|
319 |
+
p_len,
|
320 |
+
f0_up_key,
|
321 |
+
f0_method,
|
322 |
+
filter_radius,
|
323 |
+
inp_f0,
|
324 |
+
)
|
325 |
+
pitch = pitch[:p_len]
|
326 |
+
pitchf = pitchf[:p_len]
|
327 |
+
if pipe.device == "mps":
|
328 |
+
pitchf = pitchf.astype(np.float32)
|
329 |
+
pitch = torch.tensor(
|
330 |
+
pitch, device=pipe.device
|
331 |
+
).unsqueeze(0).long()
|
332 |
+
pitchf = torch.tensor(
|
333 |
+
pitchf, device=pipe.device
|
334 |
+
).unsqueeze(0).float()
|
335 |
+
|
336 |
+
t2 = ttime()
|
337 |
+
times[1] += t2 - t1
|
338 |
+
for t in opt_ts:
|
339 |
+
t = t // pipe.window * pipe.window
|
340 |
+
if if_f0 == 1:
|
341 |
+
pitch_slice = pitch[
|
342 |
+
:, s // pipe.window: (t + pipe.t_pad2) // pipe.window
|
343 |
+
]
|
344 |
+
pitchf_slice = pitchf[
|
345 |
+
:, s // pipe.window: (t + pipe.t_pad2) // pipe.window
|
346 |
+
]
|
347 |
+
else:
|
348 |
+
pitch_slice = None
|
349 |
+
pitchf_slice = None
|
350 |
+
|
351 |
+
audio_slice = audio_pad[s:t + pipe.t_pad2 + pipe.window]
|
352 |
+
audio_opt.append(
|
353 |
+
pipe.vc(
|
354 |
+
self.hu_bert_model,
|
355 |
+
net_g,
|
356 |
+
sid,
|
357 |
+
audio_slice,
|
358 |
+
pitch_slice,
|
359 |
+
pitchf_slice,
|
360 |
+
times,
|
361 |
+
index,
|
362 |
+
big_npy,
|
363 |
+
index_rate,
|
364 |
+
version,
|
365 |
+
protect,
|
366 |
+
)[pipe.t_pad_tgt:-pipe.t_pad_tgt]
|
367 |
+
)
|
368 |
+
s = t
|
369 |
+
|
370 |
+
pitch_end_slice = pitch[
|
371 |
+
:, t // pipe.window:
|
372 |
+
] if t is not None else pitch
|
373 |
+
pitchf_end_slice = pitchf[
|
374 |
+
:, t // pipe.window:
|
375 |
+
] if t is not None else pitchf
|
376 |
+
|
377 |
+
audio_opt.append(
|
378 |
+
pipe.vc(
|
379 |
+
self.hu_bert_model,
|
380 |
+
net_g,
|
381 |
+
sid,
|
382 |
+
audio_pad[t:],
|
383 |
+
pitch_end_slice,
|
384 |
+
pitchf_end_slice,
|
385 |
+
times,
|
386 |
+
index,
|
387 |
+
big_npy,
|
388 |
+
index_rate,
|
389 |
+
version,
|
390 |
+
protect,
|
391 |
+
)[pipe.t_pad_tgt:-pipe.t_pad_tgt]
|
392 |
+
)
|
393 |
+
|
394 |
+
audio_opt = np.concatenate(audio_opt)
|
395 |
+
if rms_mix_rate != 1:
|
396 |
+
audio_opt = change_rms(
|
397 |
+
audio, 16000, audio_opt, tgt_sr, rms_mix_rate
|
398 |
+
)
|
399 |
+
if resample_sr >= 16000 and tgt_sr != resample_sr:
|
400 |
+
audio_opt = librosa.resample(
|
401 |
+
audio_opt, orig_sr=tgt_sr, target_sr=resample_sr
|
402 |
+
)
|
403 |
+
audio_max = np.abs(audio_opt).max() / 0.99
|
404 |
+
max_int16 = 32768
|
405 |
+
if audio_max > 1:
|
406 |
+
max_int16 /= audio_max
|
407 |
+
audio_opt = (audio_opt * max_int16).astype(np.int16)
|
408 |
+
del pitch, pitchf, sid
|
409 |
+
if torch.cuda.is_available():
|
410 |
+
torch.cuda.empty_cache()
|
411 |
+
|
412 |
+
if tgt_sr != resample_sr >= 16000:
|
413 |
+
final_sr = resample_sr
|
414 |
+
else:
|
415 |
+
final_sr = tgt_sr
|
416 |
+
|
417 |
+
"""
|
418 |
+
"Success.\n %s\nTime:\n npy:%ss, f0:%ss, infer:%ss" % (
|
419 |
+
times[0],
|
420 |
+
times[1],
|
421 |
+
times[2],
|
422 |
+
), (final_sr, audio_opt)
|
423 |
+
|
424 |
+
"""
|
425 |
+
|
426 |
+
if overwrite:
|
427 |
+
output_audio_path = input_audio_path # Overwrite
|
428 |
+
else:
|
429 |
+
basename = os.path.basename(input_audio_path)
|
430 |
+
dirname = os.path.dirname(input_audio_path)
|
431 |
+
|
432 |
+
new_basename = basename.split(
|
433 |
+
'.')[0] + "_edited." + basename.split('.')[-1]
|
434 |
+
new_path = os.path.join(dirname, new_basename)
|
435 |
+
logger.info(str(new_path))
|
436 |
+
|
437 |
+
output_audio_path = new_path
|
438 |
+
|
439 |
+
# Save file
|
440 |
+
sf.write(
|
441 |
+
file=output_audio_path,
|
442 |
+
samplerate=final_sr,
|
443 |
+
data=audio_opt
|
444 |
+
)
|
445 |
+
|
446 |
+
self.model_config[task_id]["result"].append(output_audio_path)
|
447 |
+
self.output_list.append(output_audio_path)
|
448 |
+
|
449 |
+
def make_test(
|
450 |
+
self,
|
451 |
+
tts_text,
|
452 |
+
tts_voice,
|
453 |
+
model_path,
|
454 |
+
index_path,
|
455 |
+
transpose,
|
456 |
+
f0_method,
|
457 |
+
):
|
458 |
+
|
459 |
+
folder_test = "test"
|
460 |
+
tag = "test_edge"
|
461 |
+
tts_file = "test/test.wav"
|
462 |
+
tts_edited = "test/test_edited.wav"
|
463 |
+
|
464 |
+
create_directories(folder_test)
|
465 |
+
remove_directory_contents(folder_test)
|
466 |
+
|
467 |
+
if "SET_LIMIT" == os.getenv("DEMO"):
|
468 |
+
if len(tts_text) > 60:
|
469 |
+
tts_text = tts_text[:60]
|
470 |
+
logger.warning("DEMO; limit to 60 characters")
|
471 |
+
|
472 |
+
try:
|
473 |
+
asyncio.run(edge_tts.Communicate(
|
474 |
+
tts_text, "-".join(tts_voice.split('-')[:-1])
|
475 |
+
).save(tts_file))
|
476 |
+
except Exception as e:
|
477 |
+
raise ValueError(
|
478 |
+
"No audio was received. Please change the "
|
479 |
+
f"tts voice for {tts_voice}. Error: {str(e)}"
|
480 |
+
)
|
481 |
+
|
482 |
+
shutil.copy(tts_file, tts_edited)
|
483 |
+
|
484 |
+
self.apply_conf(
|
485 |
+
tag=tag,
|
486 |
+
file_model=model_path,
|
487 |
+
pitch_algo=f0_method,
|
488 |
+
pitch_lvl=transpose,
|
489 |
+
file_index=index_path,
|
490 |
+
index_influence=0.66,
|
491 |
+
respiration_median_filtering=3,
|
492 |
+
envelope_ratio=0.25,
|
493 |
+
consonant_breath_protection=0.33,
|
494 |
+
)
|
495 |
+
|
496 |
+
self(
|
497 |
+
audio_files=tts_edited,
|
498 |
+
tag_list=tag,
|
499 |
+
overwrite=True
|
500 |
+
)
|
501 |
+
|
502 |
+
return tts_edited, tts_file
|
503 |
+
|
504 |
+
def run_threads(self, threads):
|
505 |
+
# Start threads
|
506 |
+
for thread in threads:
|
507 |
+
thread.start()
|
508 |
+
|
509 |
+
# Wait for all threads to finish
|
510 |
+
for thread in threads:
|
511 |
+
thread.join()
|
512 |
+
|
513 |
+
gc.collect()
|
514 |
+
torch.cuda.empty_cache()
|
515 |
+
|
516 |
+
def unload_models(self):
|
517 |
+
self.hu_bert_model = None
|
518 |
+
self.model_pitch_estimator = None
|
519 |
+
gc.collect()
|
520 |
+
torch.cuda.empty_cache()
|
521 |
+
|
522 |
+
def __call__(
|
523 |
+
self,
|
524 |
+
audio_files=[],
|
525 |
+
tag_list=[],
|
526 |
+
overwrite=False,
|
527 |
+
parallel_workers=1,
|
528 |
+
):
|
529 |
+
logger.info(f"Parallel workers: {str(parallel_workers)}")
|
530 |
+
|
531 |
+
self.output_list = []
|
532 |
+
|
533 |
+
if not self.model_config:
|
534 |
+
raise ValueError("No model has been configured for inference")
|
535 |
+
|
536 |
+
if isinstance(audio_files, str):
|
537 |
+
audio_files = [audio_files]
|
538 |
+
if isinstance(tag_list, str):
|
539 |
+
tag_list = [tag_list]
|
540 |
+
|
541 |
+
if not audio_files:
|
542 |
+
raise ValueError("No audio found to convert")
|
543 |
+
if not tag_list:
|
544 |
+
tag_list = [list(self.model_config.keys())[-1]] * len(audio_files)
|
545 |
+
|
546 |
+
if len(audio_files) > len(tag_list):
|
547 |
+
logger.info("Extend tag list to match audio files")
|
548 |
+
extend_number = len(audio_files) - len(tag_list)
|
549 |
+
tag_list.extend([tag_list[0]] * extend_number)
|
550 |
+
|
551 |
+
if len(audio_files) < len(tag_list):
|
552 |
+
logger.info("Cut list tags")
|
553 |
+
tag_list = tag_list[:len(audio_files)]
|
554 |
+
|
555 |
+
tag_file_pairs = list(zip(tag_list, audio_files))
|
556 |
+
sorted_tag_file = sorted(tag_file_pairs, key=lambda x: x[0])
|
557 |
+
|
558 |
+
# Base params
|
559 |
+
if not self.hu_bert_model:
|
560 |
+
self.hu_bert_model = load_hu_bert(self.config)
|
561 |
+
|
562 |
+
cache_params = None
|
563 |
+
threads = []
|
564 |
+
progress_bar = tqdm(total=len(tag_list), desc="Progress")
|
565 |
+
for i, (id_tag, input_audio_path) in enumerate(sorted_tag_file):
|
566 |
+
|
567 |
+
if id_tag not in self.model_config.keys():
|
568 |
+
logger.info(
|
569 |
+
f"No configured model for {id_tag} with {input_audio_path}"
|
570 |
+
)
|
571 |
+
continue
|
572 |
+
|
573 |
+
if (
|
574 |
+
len(threads) >= parallel_workers
|
575 |
+
or cache_params != id_tag
|
576 |
+
and cache_params is not None
|
577 |
+
):
|
578 |
+
|
579 |
+
self.run_threads(threads)
|
580 |
+
progress_bar.update(len(threads))
|
581 |
+
|
582 |
+
threads = []
|
583 |
+
|
584 |
+
if cache_params != id_tag:
|
585 |
+
|
586 |
+
self.model_config[id_tag]["result"] = []
|
587 |
+
|
588 |
+
# Unload previous
|
589 |
+
(
|
590 |
+
n_spk,
|
591 |
+
tgt_sr,
|
592 |
+
net_g,
|
593 |
+
pipe,
|
594 |
+
cpt,
|
595 |
+
version,
|
596 |
+
if_f0,
|
597 |
+
index_rate,
|
598 |
+
index,
|
599 |
+
big_npy,
|
600 |
+
inp_f0,
|
601 |
+
) = [None] * 11
|
602 |
+
gc.collect()
|
603 |
+
torch.cuda.empty_cache()
|
604 |
+
|
605 |
+
# Model params
|
606 |
+
params = self.model_config[id_tag]
|
607 |
+
|
608 |
+
model_path = params["file_model"]
|
609 |
+
f0_method = params["pitch_algo"]
|
610 |
+
file_index = params["file_index"]
|
611 |
+
index_rate = params["index_influence"]
|
612 |
+
f0_file = params["file_pitch_algo"]
|
613 |
+
|
614 |
+
# Load model
|
615 |
+
(
|
616 |
+
n_spk,
|
617 |
+
tgt_sr,
|
618 |
+
net_g,
|
619 |
+
pipe,
|
620 |
+
cpt,
|
621 |
+
version
|
622 |
+
) = load_trained_model(model_path, self.config)
|
623 |
+
if_f0 = cpt.get("f0", 1) # pitch data
|
624 |
+
|
625 |
+
# Load index
|
626 |
+
if os.path.exists(file_index) and index_rate != 0:
|
627 |
+
try:
|
628 |
+
index = faiss.read_index(file_index)
|
629 |
+
big_npy = index.reconstruct_n(0, index.ntotal)
|
630 |
+
except Exception as error:
|
631 |
+
logger.error(f"Index: {str(error)}")
|
632 |
+
index_rate = 0
|
633 |
+
index = big_npy = None
|
634 |
+
else:
|
635 |
+
logger.warning("File index not found")
|
636 |
+
index_rate = 0
|
637 |
+
index = big_npy = None
|
638 |
+
|
639 |
+
# Load f0 file
|
640 |
+
inp_f0 = None
|
641 |
+
if os.path.exists(f0_file):
|
642 |
+
try:
|
643 |
+
with open(f0_file, "r") as f:
|
644 |
+
lines = f.read().strip("\n").split("\n")
|
645 |
+
inp_f0 = []
|
646 |
+
for line in lines:
|
647 |
+
inp_f0.append([float(i) for i in line.split(",")])
|
648 |
+
inp_f0 = np.array(inp_f0, dtype="float32")
|
649 |
+
except Exception as error:
|
650 |
+
logger.error(f"f0 file: {str(error)}")
|
651 |
+
|
652 |
+
if "rmvpe" in f0_method:
|
653 |
+
if not self.model_pitch_estimator:
|
654 |
+
from lib.rmvpe import RMVPE
|
655 |
+
|
656 |
+
logger.info("Loading vocal pitch estimator model")
|
657 |
+
self.model_pitch_estimator = RMVPE(
|
658 |
+
"rmvpe.pt",
|
659 |
+
is_half=self.config.is_half,
|
660 |
+
device=self.config.device
|
661 |
+
)
|
662 |
+
|
663 |
+
pipe.model_rmvpe = self.model_pitch_estimator
|
664 |
+
|
665 |
+
cache_params = id_tag
|
666 |
+
|
667 |
+
# self.infer(
|
668 |
+
# id_tag,
|
669 |
+
# params,
|
670 |
+
# # load model
|
671 |
+
# n_spk,
|
672 |
+
# tgt_sr,
|
673 |
+
# net_g,
|
674 |
+
# pipe,
|
675 |
+
# cpt,
|
676 |
+
# version,
|
677 |
+
# if_f0,
|
678 |
+
# # load index
|
679 |
+
# index_rate,
|
680 |
+
# index,
|
681 |
+
# big_npy,
|
682 |
+
# # load f0 file
|
683 |
+
# inp_f0,
|
684 |
+
# # output file
|
685 |
+
# input_audio_path,
|
686 |
+
# overwrite,
|
687 |
+
# )
|
688 |
+
|
689 |
+
thread = threading.Thread(
|
690 |
+
target=self.infer,
|
691 |
+
args=(
|
692 |
+
id_tag,
|
693 |
+
params,
|
694 |
+
# loaded model
|
695 |
+
n_spk,
|
696 |
+
tgt_sr,
|
697 |
+
net_g,
|
698 |
+
pipe,
|
699 |
+
cpt,
|
700 |
+
version,
|
701 |
+
if_f0,
|
702 |
+
# loaded index
|
703 |
+
index_rate,
|
704 |
+
index,
|
705 |
+
big_npy,
|
706 |
+
# loaded f0 file
|
707 |
+
inp_f0,
|
708 |
+
# audio file
|
709 |
+
input_audio_path,
|
710 |
+
overwrite,
|
711 |
+
)
|
712 |
+
)
|
713 |
+
|
714 |
+
threads.append(thread)
|
715 |
+
|
716 |
+
# Run last
|
717 |
+
if threads:
|
718 |
+
self.run_threads(threads)
|
719 |
+
|
720 |
+
progress_bar.update(len(threads))
|
721 |
+
progress_bar.close()
|
722 |
+
|
723 |
+
final_result = []
|
724 |
+
valid_tags = set(tag_list)
|
725 |
+
for tag in valid_tags:
|
726 |
+
if (
|
727 |
+
tag in self.model_config.keys()
|
728 |
+
and "result" in self.model_config[tag].keys()
|
729 |
+
):
|
730 |
+
final_result.extend(self.model_config[tag]["result"])
|
731 |
+
|
732 |
+
return final_result
|