r3gm commited on
Commit
b152010
·
1 Parent(s): 62b1e34
.gitignore ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib64/
18
+ parts/
19
+ sdist/
20
+ var/
21
+ wheels/
22
+ share/python-wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .nox/
42
+ .coverage
43
+ .coverage.*
44
+ .cache
45
+ nosetests.xml
46
+ coverage.xml
47
+ *.cover
48
+ *.py,cover
49
+ .hypothesis/
50
+ .pytest_cache/
51
+ cover/
52
+
53
+ # Translations
54
+ *.mo
55
+ *.pot
56
+
57
+ # Django stuff:
58
+ *.log
59
+ local_settings.py
60
+ db.sqlite3
61
+ db.sqlite3-journal
62
+
63
+ # Flask stuff:
64
+ instance/
65
+ .webassets-cache
66
+
67
+ # Scrapy stuff:
68
+ .scrapy
69
+
70
+ # Sphinx documentation
71
+ docs/_build/
72
+
73
+ # PyBuilder
74
+ .pybuilder/
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ # For a library or package, you might want to ignore these files since the code is
86
+ # intended to run in multiple environments; otherwise, check them in:
87
+ # .python-version
88
+
89
+ # pipenv
90
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
92
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
93
+ # install all needed dependencies.
94
+ #Pipfile.lock
95
+
96
+ # poetry
97
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
98
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
99
+ # commonly ignored for libraries.
100
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
101
+ #poetry.lock
102
+
103
+ # pdm
104
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
105
+ #pdm.lock
106
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
107
+ # in version control.
108
+ # https://pdm.fming.dev/#use-with-ide
109
+ .pdm.toml
110
+
111
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
112
+ __pypackages__/
113
+
114
+ # Celery stuff
115
+ celerybeat-schedule
116
+ celerybeat.pid
117
+
118
+ # SageMath parsed files
119
+ *.sage.py
120
+
121
+ # Environments
122
+ .env
123
+ .venv
124
+ env/
125
+ venv/
126
+ ENV/
127
+ env.bak/
128
+ venv.bak/
129
+
130
+ # Spyder project settings
131
+ .spyderproject
132
+ .spyproject
133
+
134
+ # Rope project settings
135
+ .ropeproject
136
+
137
+ # mkdocs documentation
138
+ /site
139
+
140
+ # mypy
141
+ .mypy_cache/
142
+ .dmypy.json
143
+ dmypy.json
144
+
145
+ # Pyre type checker
146
+ .pyre/
147
+
148
+ # pytype static type analyzer
149
+ .pytype/
150
+
151
+ # Cython debug symbols
152
+ cython_debug/
153
+
154
+ # PyCharm
155
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
156
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
157
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
158
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
159
+ #.idea/
160
+
161
+ # Ignore
162
+ sub_tra.*
163
+ sub_ori.*
164
+ SPEAKER_00.*
165
+ SPEAKER_01.*
166
+ SPEAKER_02.*
167
+ SPEAKER_03.*
168
+ SPEAKER_04.*
169
+ SPEAKER_05.*
170
+ SPEAKER_06.*
171
+ SPEAKER_07.*
172
+ SPEAKER_08.*
173
+ SPEAKER_09.*
174
+ SPEAKER_10.*
175
+ SPEAKER_11.*
176
+ task_subtitle.*
177
+ *.mp3
178
+ *.mp4
179
+ *.ogg
180
+ *.wav
181
+ *.mkv
182
+ *.webm
183
+ *.avi
184
+ *.mpg
185
+ *.mov
186
+ *.ogv
187
+ *.wmv
188
+ test.py
189
+ list.txt
190
+ text_preprocessor.txt
191
+ text_translation.txt
192
+ *.srt
193
+ *.vtt
194
+ *.tsv
195
+ *.aud
196
+ *.ass
197
+ *.pt
198
+ .vscode/
199
+ mdx_models/*.onnx
200
+ _XTTS_/
201
+ downloads/
202
+ logs/
203
+ weights/
204
+ clean_song_output/
205
+ audio2/
206
+ audio/
207
+ outputs/
208
+ processed/
209
+ OPENVOICE_MODELS/
210
+ PIPER_MODELS/
211
+ WHISPER_MODELS/
212
+ whisper_api_audio_parts/
213
+ uroman/
214
+ pdf_images/
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
SoniTranslate_Colab.ipynb ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4",
8
+ "include_colab_link": true
9
+ },
10
+ "kernelspec": {
11
+ "name": "python3",
12
+ "display_name": "Python 3"
13
+ },
14
+ "language_info": {
15
+ "name": "python"
16
+ },
17
+ "accelerator": "GPU"
18
+ },
19
+ "cells": [
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {
23
+ "id": "view-in-github",
24
+ "colab_type": "text"
25
+ },
26
+ "source": [
27
+ "<a href=\"https://colab.research.google.com/github/R3gm/SoniTranslate/blob/main/SoniTranslate_Colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "markdown",
32
+ "source": [
33
+ "# SoniTranslate\n",
34
+ "\n",
35
+ "| Description | Link |\n",
36
+ "| ----------- | ---- |\n",
37
+ "| 🎉 Repository | [![GitHub Repository](https://img.shields.io/badge/GitHub-Repository-black?style=flat-square&logo=github)](https://github.com/R3gm/SoniTranslate/) |\n",
38
+ "| 🚀 Online Demo in HF | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/r3gm/SoniTranslate_translate_audio_of_a_video_content) |\n",
39
+ "\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "8lw0EgLex-YZ"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "metadata": {
50
+ "id": "LUgwm0rfx0_J",
51
+ "cellView": "form"
52
+ },
53
+ "outputs": [],
54
+ "source": [
55
+ "# @title Install requirements for SoniTranslate\n",
56
+ "!git clone https://github.com/r3gm/SoniTranslate.git\n",
57
+ "%cd SoniTranslate\n",
58
+ "\n",
59
+ "!apt install git-lfs\n",
60
+ "!git lfs install\n",
61
+ "\n",
62
+ "!sed -i 's|git+https://github.com/R3gm/whisperX.git@cuda_11_8|git+https://github.com/R3gm/whisperX.git@cuda_12_x|' requirements_base.txt\n",
63
+ "!pip install -q -r requirements_base.txt\n",
64
+ "!pip install -q -r requirements_extra.txt\n",
65
+ "!pip install -q ort-nightly-gpu --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ort-cuda-12-nightly/pypi/simple/\n",
66
+ "\n",
67
+ "Install_PIPER_TTS = True # @param {type:\"boolean\"}\n",
68
+ "\n",
69
+ "if Install_PIPER_TTS:\n",
70
+ " !pip install -q piper-tts==1.2.0\n",
71
+ "\n",
72
+ "Install_Coqui_XTTS = True # @param {type:\"boolean\"}\n",
73
+ "\n",
74
+ "if Install_Coqui_XTTS:\n",
75
+ " !pip install -q -r requirements_xtts.txt\n",
76
+ " !pip install -q TTS==0.21.1 --no-deps"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "source": [
82
+ "One important step is to accept the license agreement for using Pyannote. You need to have an account on Hugging Face and `accept the license to use the models`: https://huggingface.co/pyannote/speaker-diarization and https://huggingface.co/pyannote/segmentation\n",
83
+ "\n",
84
+ "\n",
85
+ "\n",
86
+ "\n",
87
+ "Get your KEY TOKEN here: https://hf.co/settings/tokens"
88
+ ],
89
+ "metadata": {
90
+ "id": "LTaTstXPXNg2"
91
+ }
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "source": [
96
+ "#@markdown # `RUN THE WEB APP`\n",
97
+ "YOUR_HF_TOKEN = \"\" #@param {type:'string'}\n",
98
+ "%env YOUR_HF_TOKEN={YOUR_HF_TOKEN}\n",
99
+ "theme = \"Taithrah/Minimal\" # @param [\"Taithrah/Minimal\", \"aliabid94/new-theme\", \"gstaff/xkcd\", \"ParityError/LimeFace\", \"abidlabs/pakistan\", \"rottenlittlecreature/Moon_Goblin\", \"ysharma/llamas\", \"gradio/dracula_revamped\"]\n",
100
+ "interface_language = \"english\" # @param ['arabic', 'azerbaijani', 'chinese_zh_cn', 'english', 'french', 'german', 'hindi', 'indonesian', 'italian', 'japanese', 'korean', 'marathi', 'polish', 'portuguese', 'russian', 'spanish', 'swedish', 'turkish', 'ukrainian', 'vietnamese']\n",
101
+ "verbosity_level = \"info\" # @param [\"debug\", \"info\", \"warning\", \"error\", \"critical\"]\n",
102
+ "\n",
103
+ "\n",
104
+ "%cd /content/SoniTranslate\n",
105
+ "!python app_rvc.py --theme {theme} --verbosity_level {verbosity_level} --language {interface_language} --public_url"
106
+ ],
107
+ "metadata": {
108
+ "id": "XkhXfaFw4R4J",
109
+ "cellView": "form"
110
+ },
111
+ "execution_count": null,
112
+ "outputs": []
113
+ },
114
+ {
115
+ "cell_type": "markdown",
116
+ "source": [
117
+ "Open the `public URL` when it appears"
118
+ ],
119
+ "metadata": {
120
+ "id": "KJW3KrhZJh0u"
121
+ }
122
+ }
123
+ ]
124
+ }
app.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ import os
2
+ os.system("python app_rvc.py --language french --theme aliabid94/new-theme")
app_rvc.py ADDED
The diff for this file is too large to render. See raw diff
 
assets/logo.jpeg ADDED
docs/windows_install.md ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Install Locally Windows
2
+
3
+ ### Before You Start
4
+
5
+ Before you start installing and using SoniTranslate, there are a few things you need to do:
6
+
7
+ 1. Install Microsoft Visual C++ Build Tools, MSVC and Windows 10 SDK:
8
+
9
+ * Go to the [Visual Studio downloads page](https://visualstudio.microsoft.com/visual-cpp-build-tools/); Or maybe you already have **Visual Studio Installer**? Open it. If you have it already click modify.
10
+ * Download and install the "Build Tools for Visual Studio" if you don't have it.
11
+ * During installation, under "Workloads", select "C++ build tools" and ensure the latest versions of "MSVCv142 - VS 2019 C++ x64/x86 build tools" and "Windows 10 SDK" are selected ("Windows 11 SDK" if you are using Windows 11); OR go to individual components and find those two listed.
12
+ * Complete the installation.
13
+
14
+ 2. Verify the NVIDIA driver on Windows using the command line:
15
+
16
+ * **Open Command Prompt:** Press `Win + R`, type `cmd`, then press `Enter`.
17
+
18
+ * **Type the command:** `nvidia-smi` and press `Enter`.
19
+
20
+ * **Look for "CUDA Version"** in the output.
21
+
22
+ ```
23
+ +-----------------------------------------------------------------------------+
24
+ | NVIDIA-SMI 522.25 Driver Version: 522.25 CUDA Version: 11.8 |
25
+ |-------------------------------+----------------------+----------------------+
26
+ ```
27
+
28
+ 3. If you see that your CUDA version is less than 11.8, you should update your NVIDIA driver. Visit the NVIDIA website's driver download page (https://www.nvidia.com/Download/index.aspx) and enter your graphics card information.
29
+
30
+ 4. Accept the license agreement for using Pyannote. You need to have an account on Hugging Face and `accept the license to use the models`: https://huggingface.co/pyannote/speaker-diarization and https://huggingface.co/pyannote/segmentation
31
+ 5. Create a [huggingface token](https://huggingface.co/settings/tokens). Hugging Face is a natural language processing platform that provides access to state-of-the-art models and tools. You will need to create a token in order to use some of the automatic model download features in SoniTranslate. Follow the instructions on the Hugging Face website to create a token.
32
+ 6. Install [Anaconda](https://www.anaconda.com/) or [Miniconda](https://docs.anaconda.com/free/miniconda/miniconda-install/). Anaconda is a free and open-source distribution of Python and R. It includes a package manager called conda that makes it easy to install and manage Python environments and packages. Follow the instructions on the Anaconda website to download and install Anaconda on your system.
33
+ 7. Install Git for your system. Git is a version control system that helps you track changes to your code and collaborate with other developers. You can install Git with Anaconda by running `conda install -c anaconda git -y` in your terminal (Do this after step 1 in the following section.). If you have trouble installing Git via Anaconda, you can use the following link instead:
34
+ - [Git for Windows](https://git-scm.com/download/win)
35
+
36
+ Once you have completed these steps, you will be ready to install SoniTranslate.
37
+
38
+ ### Getting Started
39
+
40
+ To install SoniTranslate, follow these steps:
41
+
42
+ 1. Create a suitable anaconda environment for SoniTranslate and activate it:
43
+
44
+ ```
45
+ conda create -n sonitr python=3.10 -y
46
+ conda activate sonitr
47
+ ```
48
+
49
+ 2. Clone this github repository and navigate to it:
50
+ ```
51
+ git clone https://github.com/r3gm/SoniTranslate.git
52
+ cd SoniTranslate
53
+ ```
54
+ 3. Install CUDA Toolkit 11.8.0
55
+
56
+ ```
57
+ conda install -c "nvidia/label/cuda-11.8.0" cuda-toolkit -y
58
+ ```
59
+
60
+ 4. Install PyTorch using conda
61
+ ```
62
+ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
63
+ ```
64
+
65
+ 5. Install required packages:
66
+
67
+ ```
68
+ pip install -r requirements_base.txt -v
69
+ pip install -r requirements_extra.txt -v
70
+ pip install onnxruntime-gpu
71
+ ```
72
+
73
+ 6. Install [ffmpeg](https://ffmpeg.org/download.html). FFmpeg is a free software project that produces libraries and programs for handling multimedia data. You will need it to process audio and video files. You can install ffmpeg with Anaconda by running `conda install -y ffmpeg` in your terminal (recommended). If you have trouble installing ffmpeg via Anaconda, you can use the following link instead: (https://ffmpeg.org/ffmpeg.html). Once it is installed, make sure it is in your PATH by running `ffmpeg -h` in your terminal. If you don't get an error message, you're good to go.
74
+
75
+ 7. Optional install:
76
+
77
+ After installing FFmpeg, you can install these optional packages.
78
+
79
+ [Coqui XTTS](https://github.com/coqui-ai/TTS) is a text-to-speech (TTS) model that lets you generate realistic voices in different languages. It can clone voices with just a short audio clip, even speak in a different language! It's like having a personal voice mimic for any text you need spoken.
80
+
81
+ ```
82
+ pip install -q -r requirements_xtts.txt
83
+ pip install -q TTS==0.21.1 --no-deps
84
+ ```
85
+
86
+ [Piper TTS](https://github.com/rhasspy/piper) is a fast, local neural text to speech system that sounds great and is optimized for the Raspberry Pi 4. Piper is used in a variety of projects. Voices are trained with VITS and exported to the onnxruntime.
87
+
88
+ 🚧 For Windows users, it's important to note that the Python module piper-tts is not fully supported on this operating system. While it works smoothly on Linux, Windows compatibility is currently experimental. If you still wish to install it on Windows, you can follow this experimental method:
89
+
90
+ ```
91
+ pip install https://github.com/R3gm/piper-phonemize/releases/download/1.2.0/piper_phonemize-1.2.0-cp310-cp310-win_amd64.whl
92
+ pip install sherpa-onnx==1.9.12
93
+ pip install piper-tts==1.2.0 --no-deps
94
+ ```
95
+
96
+ 8. Setting your [Hugging Face token](https://huggingface.co/settings/tokens) as an environment variable in quotes:
97
+
98
+ ```
99
+ conda env config vars set YOUR_HF_TOKEN="YOUR_HUGGING_FACE_TOKEN_HERE"
100
+ conda deactivate
101
+ ```
102
+
103
+
104
+ ### Running SoniTranslate
105
+
106
+ To run SoniTranslate locally, make sure the `sonitr` conda environment is active:
107
+
108
+ ```
109
+ conda activate sonitr
110
+ ```
111
+
112
+ Then navigate to the `SoniTranslate` folder and run either the `app_rvc.py`
113
+
114
+ ```
115
+ python app_rvc.py
116
+ ```
117
+ When the `local URL` `http://127.0.0.1:7860` is displayed in the terminal, simply open this URL in your web browser to access the SoniTranslate interface.
118
+
119
+ ### Stop and close SoniTranslate.
120
+
121
+ In most environments, you can stop the execution by pressing Ctrl+C in the terminal where you launched the script `app_rvc.py`. This will interrupt the program and stop the Gradio app.
122
+ To deactivate the Conda environment, you can use the following command:
123
+
124
+ ```
125
+ conda deactivate
126
+ ```
127
+
128
+ This will deactivate the currently active Conda environment sonitr, and you'll return to the base environment or the global Python environment.
129
+
130
+ ### Starting Over
131
+
132
+ If you need to start over from scratch, you can delete the `SoniTranslate` folder and remove the `sonitr` conda environment with the following set of commands:
133
+
134
+ ```
135
+ conda deactivate
136
+ conda env remove -n sonitr
137
+ ```
138
+
139
+ With the `sonitr` environment removed, you can start over with a fresh installation.
140
+
141
+ ### Notes
142
+ - To use OpenAI's GPT API for translation, set up your OpenAI API key as an environment variable in quotes:
143
+
144
+ ```
145
+ conda activate sonitr
146
+ conda env config vars set OPENAI_API_KEY="your-api-key-here"
147
+ conda deactivate
148
+ ```
149
+
150
+ - Alternatively, you can install the CUDA Toolkit 11.8.0 directly on your system [CUDA Toolkit 11.8.0](https://developer.nvidia.com/cuda-11-8-0-download-archive).
lib/audio.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ffmpeg
2
+ import numpy as np
3
+
4
+
5
+ def load_audio(file, sr):
6
+ try:
7
+ # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
8
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
9
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
10
+ file = (
11
+ file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
12
+ ) # To prevent beginners from copying paths with leading or trailing spaces, quotation marks, and line breaks.
13
+ out, _ = (
14
+ ffmpeg.input(file, threads=0)
15
+ .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
16
+ .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
17
+ )
18
+ except Exception as e:
19
+ raise RuntimeError(f"Failed to load audio: {e}")
20
+
21
+ return np.frombuffer(out, np.float32).flatten()
lib/infer_pack/attentions.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from lib.infer_pack import commons
9
+ from lib.infer_pack import modules
10
+ from lib.infer_pack.modules import LayerNorm
11
+
12
+
13
+ class Encoder(nn.Module):
14
+ def __init__(
15
+ self,
16
+ hidden_channels,
17
+ filter_channels,
18
+ n_heads,
19
+ n_layers,
20
+ kernel_size=1,
21
+ p_dropout=0.0,
22
+ window_size=10,
23
+ **kwargs
24
+ ):
25
+ super().__init__()
26
+ self.hidden_channels = hidden_channels
27
+ self.filter_channels = filter_channels
28
+ self.n_heads = n_heads
29
+ self.n_layers = n_layers
30
+ self.kernel_size = kernel_size
31
+ self.p_dropout = p_dropout
32
+ self.window_size = window_size
33
+
34
+ self.drop = nn.Dropout(p_dropout)
35
+ self.attn_layers = nn.ModuleList()
36
+ self.norm_layers_1 = nn.ModuleList()
37
+ self.ffn_layers = nn.ModuleList()
38
+ self.norm_layers_2 = nn.ModuleList()
39
+ for i in range(self.n_layers):
40
+ self.attn_layers.append(
41
+ MultiHeadAttention(
42
+ hidden_channels,
43
+ hidden_channels,
44
+ n_heads,
45
+ p_dropout=p_dropout,
46
+ window_size=window_size,
47
+ )
48
+ )
49
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
50
+ self.ffn_layers.append(
51
+ FFN(
52
+ hidden_channels,
53
+ hidden_channels,
54
+ filter_channels,
55
+ kernel_size,
56
+ p_dropout=p_dropout,
57
+ )
58
+ )
59
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
60
+
61
+ def forward(self, x, x_mask):
62
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
63
+ x = x * x_mask
64
+ for i in range(self.n_layers):
65
+ y = self.attn_layers[i](x, x, attn_mask)
66
+ y = self.drop(y)
67
+ x = self.norm_layers_1[i](x + y)
68
+
69
+ y = self.ffn_layers[i](x, x_mask)
70
+ y = self.drop(y)
71
+ x = self.norm_layers_2[i](x + y)
72
+ x = x * x_mask
73
+ return x
74
+
75
+
76
+ class Decoder(nn.Module):
77
+ def __init__(
78
+ self,
79
+ hidden_channels,
80
+ filter_channels,
81
+ n_heads,
82
+ n_layers,
83
+ kernel_size=1,
84
+ p_dropout=0.0,
85
+ proximal_bias=False,
86
+ proximal_init=True,
87
+ **kwargs
88
+ ):
89
+ super().__init__()
90
+ self.hidden_channels = hidden_channels
91
+ self.filter_channels = filter_channels
92
+ self.n_heads = n_heads
93
+ self.n_layers = n_layers
94
+ self.kernel_size = kernel_size
95
+ self.p_dropout = p_dropout
96
+ self.proximal_bias = proximal_bias
97
+ self.proximal_init = proximal_init
98
+
99
+ self.drop = nn.Dropout(p_dropout)
100
+ self.self_attn_layers = nn.ModuleList()
101
+ self.norm_layers_0 = nn.ModuleList()
102
+ self.encdec_attn_layers = nn.ModuleList()
103
+ self.norm_layers_1 = nn.ModuleList()
104
+ self.ffn_layers = nn.ModuleList()
105
+ self.norm_layers_2 = nn.ModuleList()
106
+ for i in range(self.n_layers):
107
+ self.self_attn_layers.append(
108
+ MultiHeadAttention(
109
+ hidden_channels,
110
+ hidden_channels,
111
+ n_heads,
112
+ p_dropout=p_dropout,
113
+ proximal_bias=proximal_bias,
114
+ proximal_init=proximal_init,
115
+ )
116
+ )
117
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
118
+ self.encdec_attn_layers.append(
119
+ MultiHeadAttention(
120
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
121
+ )
122
+ )
123
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
124
+ self.ffn_layers.append(
125
+ FFN(
126
+ hidden_channels,
127
+ hidden_channels,
128
+ filter_channels,
129
+ kernel_size,
130
+ p_dropout=p_dropout,
131
+ causal=True,
132
+ )
133
+ )
134
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
135
+
136
+ def forward(self, x, x_mask, h, h_mask):
137
+ """
138
+ x: decoder input
139
+ h: encoder output
140
+ """
141
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
142
+ device=x.device, dtype=x.dtype
143
+ )
144
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
145
+ x = x * x_mask
146
+ for i in range(self.n_layers):
147
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
148
+ y = self.drop(y)
149
+ x = self.norm_layers_0[i](x + y)
150
+
151
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
152
+ y = self.drop(y)
153
+ x = self.norm_layers_1[i](x + y)
154
+
155
+ y = self.ffn_layers[i](x, x_mask)
156
+ y = self.drop(y)
157
+ x = self.norm_layers_2[i](x + y)
158
+ x = x * x_mask
159
+ return x
160
+
161
+
162
+ class MultiHeadAttention(nn.Module):
163
+ def __init__(
164
+ self,
165
+ channels,
166
+ out_channels,
167
+ n_heads,
168
+ p_dropout=0.0,
169
+ window_size=None,
170
+ heads_share=True,
171
+ block_length=None,
172
+ proximal_bias=False,
173
+ proximal_init=False,
174
+ ):
175
+ super().__init__()
176
+ assert channels % n_heads == 0
177
+
178
+ self.channels = channels
179
+ self.out_channels = out_channels
180
+ self.n_heads = n_heads
181
+ self.p_dropout = p_dropout
182
+ self.window_size = window_size
183
+ self.heads_share = heads_share
184
+ self.block_length = block_length
185
+ self.proximal_bias = proximal_bias
186
+ self.proximal_init = proximal_init
187
+ self.attn = None
188
+
189
+ self.k_channels = channels // n_heads
190
+ self.conv_q = nn.Conv1d(channels, channels, 1)
191
+ self.conv_k = nn.Conv1d(channels, channels, 1)
192
+ self.conv_v = nn.Conv1d(channels, channels, 1)
193
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
194
+ self.drop = nn.Dropout(p_dropout)
195
+
196
+ if window_size is not None:
197
+ n_heads_rel = 1 if heads_share else n_heads
198
+ rel_stddev = self.k_channels**-0.5
199
+ self.emb_rel_k = nn.Parameter(
200
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
201
+ * rel_stddev
202
+ )
203
+ self.emb_rel_v = nn.Parameter(
204
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
205
+ * rel_stddev
206
+ )
207
+
208
+ nn.init.xavier_uniform_(self.conv_q.weight)
209
+ nn.init.xavier_uniform_(self.conv_k.weight)
210
+ nn.init.xavier_uniform_(self.conv_v.weight)
211
+ if proximal_init:
212
+ with torch.no_grad():
213
+ self.conv_k.weight.copy_(self.conv_q.weight)
214
+ self.conv_k.bias.copy_(self.conv_q.bias)
215
+
216
+ def forward(self, x, c, attn_mask=None):
217
+ q = self.conv_q(x)
218
+ k = self.conv_k(c)
219
+ v = self.conv_v(c)
220
+
221
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
222
+
223
+ x = self.conv_o(x)
224
+ return x
225
+
226
+ def attention(self, query, key, value, mask=None):
227
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
228
+ b, d, t_s, t_t = (*key.size(), query.size(2))
229
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
230
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
231
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
232
+
233
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
234
+ if self.window_size is not None:
235
+ assert (
236
+ t_s == t_t
237
+ ), "Relative attention is only available for self-attention."
238
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
239
+ rel_logits = self._matmul_with_relative_keys(
240
+ query / math.sqrt(self.k_channels), key_relative_embeddings
241
+ )
242
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
243
+ scores = scores + scores_local
244
+ if self.proximal_bias:
245
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
246
+ scores = scores + self._attention_bias_proximal(t_s).to(
247
+ device=scores.device, dtype=scores.dtype
248
+ )
249
+ if mask is not None:
250
+ scores = scores.masked_fill(mask == 0, -1e4)
251
+ if self.block_length is not None:
252
+ assert (
253
+ t_s == t_t
254
+ ), "Local attention is only available for self-attention."
255
+ block_mask = (
256
+ torch.ones_like(scores)
257
+ .triu(-self.block_length)
258
+ .tril(self.block_length)
259
+ )
260
+ scores = scores.masked_fill(block_mask == 0, -1e4)
261
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
262
+ p_attn = self.drop(p_attn)
263
+ output = torch.matmul(p_attn, value)
264
+ if self.window_size is not None:
265
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
266
+ value_relative_embeddings = self._get_relative_embeddings(
267
+ self.emb_rel_v, t_s
268
+ )
269
+ output = output + self._matmul_with_relative_values(
270
+ relative_weights, value_relative_embeddings
271
+ )
272
+ output = (
273
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
274
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
275
+ return output, p_attn
276
+
277
+ def _matmul_with_relative_values(self, x, y):
278
+ """
279
+ x: [b, h, l, m]
280
+ y: [h or 1, m, d]
281
+ ret: [b, h, l, d]
282
+ """
283
+ ret = torch.matmul(x, y.unsqueeze(0))
284
+ return ret
285
+
286
+ def _matmul_with_relative_keys(self, x, y):
287
+ """
288
+ x: [b, h, l, d]
289
+ y: [h or 1, m, d]
290
+ ret: [b, h, l, m]
291
+ """
292
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
293
+ return ret
294
+
295
+ def _get_relative_embeddings(self, relative_embeddings, length):
296
+ max_relative_position = 2 * self.window_size + 1
297
+ # Pad first before slice to avoid using cond ops.
298
+ pad_length = max(length - (self.window_size + 1), 0)
299
+ slice_start_position = max((self.window_size + 1) - length, 0)
300
+ slice_end_position = slice_start_position + 2 * length - 1
301
+ if pad_length > 0:
302
+ padded_relative_embeddings = F.pad(
303
+ relative_embeddings,
304
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
305
+ )
306
+ else:
307
+ padded_relative_embeddings = relative_embeddings
308
+ used_relative_embeddings = padded_relative_embeddings[
309
+ :, slice_start_position:slice_end_position
310
+ ]
311
+ return used_relative_embeddings
312
+
313
+ def _relative_position_to_absolute_position(self, x):
314
+ """
315
+ x: [b, h, l, 2*l-1]
316
+ ret: [b, h, l, l]
317
+ """
318
+ batch, heads, length, _ = x.size()
319
+ # Concat columns of pad to shift from relative to absolute indexing.
320
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
321
+
322
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
323
+ x_flat = x.view([batch, heads, length * 2 * length])
324
+ x_flat = F.pad(
325
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
326
+ )
327
+
328
+ # Reshape and slice out the padded elements.
329
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
330
+ :, :, :length, length - 1 :
331
+ ]
332
+ return x_final
333
+
334
+ def _absolute_position_to_relative_position(self, x):
335
+ """
336
+ x: [b, h, l, l]
337
+ ret: [b, h, l, 2*l-1]
338
+ """
339
+ batch, heads, length, _ = x.size()
340
+ # padd along column
341
+ x = F.pad(
342
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
343
+ )
344
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
345
+ # add 0's in the beginning that will skew the elements after reshape
346
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
347
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
348
+ return x_final
349
+
350
+ def _attention_bias_proximal(self, length):
351
+ """Bias for self-attention to encourage attention to close positions.
352
+ Args:
353
+ length: an integer scalar.
354
+ Returns:
355
+ a Tensor with shape [1, 1, length, length]
356
+ """
357
+ r = torch.arange(length, dtype=torch.float32)
358
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
359
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
360
+
361
+
362
+ class FFN(nn.Module):
363
+ def __init__(
364
+ self,
365
+ in_channels,
366
+ out_channels,
367
+ filter_channels,
368
+ kernel_size,
369
+ p_dropout=0.0,
370
+ activation=None,
371
+ causal=False,
372
+ ):
373
+ super().__init__()
374
+ self.in_channels = in_channels
375
+ self.out_channels = out_channels
376
+ self.filter_channels = filter_channels
377
+ self.kernel_size = kernel_size
378
+ self.p_dropout = p_dropout
379
+ self.activation = activation
380
+ self.causal = causal
381
+
382
+ if causal:
383
+ self.padding = self._causal_padding
384
+ else:
385
+ self.padding = self._same_padding
386
+
387
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
388
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
389
+ self.drop = nn.Dropout(p_dropout)
390
+
391
+ def forward(self, x, x_mask):
392
+ x = self.conv_1(self.padding(x * x_mask))
393
+ if self.activation == "gelu":
394
+ x = x * torch.sigmoid(1.702 * x)
395
+ else:
396
+ x = torch.relu(x)
397
+ x = self.drop(x)
398
+ x = self.conv_2(self.padding(x * x_mask))
399
+ return x * x_mask
400
+
401
+ def _causal_padding(self, x):
402
+ if self.kernel_size == 1:
403
+ return x
404
+ pad_l = self.kernel_size - 1
405
+ pad_r = 0
406
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
407
+ x = F.pad(x, commons.convert_pad_shape(padding))
408
+ return x
409
+
410
+ def _same_padding(self, x):
411
+ if self.kernel_size == 1:
412
+ return x
413
+ pad_l = (self.kernel_size - 1) // 2
414
+ pad_r = self.kernel_size // 2
415
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
416
+ x = F.pad(x, commons.convert_pad_shape(padding))
417
+ return x
lib/infer_pack/commons.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def init_weights(m, mean=0.0, std=0.01):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ m.weight.data.normal_(mean, std)
12
+
13
+
14
+ def get_padding(kernel_size, dilation=1):
15
+ return int((kernel_size * dilation - dilation) / 2)
16
+
17
+
18
+ def convert_pad_shape(pad_shape):
19
+ l = pad_shape[::-1]
20
+ pad_shape = [item for sublist in l for item in sublist]
21
+ return pad_shape
22
+
23
+
24
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
25
+ """KL(P||Q)"""
26
+ kl = (logs_q - logs_p) - 0.5
27
+ kl += (
28
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
29
+ )
30
+ return kl
31
+
32
+
33
+ def rand_gumbel(shape):
34
+ """Sample from the Gumbel distribution, protect from overflows."""
35
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
36
+ return -torch.log(-torch.log(uniform_samples))
37
+
38
+
39
+ def rand_gumbel_like(x):
40
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
41
+ return g
42
+
43
+
44
+ def slice_segments(x, ids_str, segment_size=4):
45
+ ret = torch.zeros_like(x[:, :, :segment_size])
46
+ for i in range(x.size(0)):
47
+ idx_str = ids_str[i]
48
+ idx_end = idx_str + segment_size
49
+ ret[i] = x[i, :, idx_str:idx_end]
50
+ return ret
51
+
52
+
53
+ def slice_segments2(x, ids_str, segment_size=4):
54
+ ret = torch.zeros_like(x[:, :segment_size])
55
+ for i in range(x.size(0)):
56
+ idx_str = ids_str[i]
57
+ idx_end = idx_str + segment_size
58
+ ret[i] = x[i, idx_str:idx_end]
59
+ return ret
60
+
61
+
62
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
63
+ b, d, t = x.size()
64
+ if x_lengths is None:
65
+ x_lengths = t
66
+ ids_str_max = x_lengths - segment_size + 1
67
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
68
+ ret = slice_segments(x, ids_str, segment_size)
69
+ return ret, ids_str
70
+
71
+
72
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
73
+ position = torch.arange(length, dtype=torch.float)
74
+ num_timescales = channels // 2
75
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
76
+ num_timescales - 1
77
+ )
78
+ inv_timescales = min_timescale * torch.exp(
79
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
80
+ )
81
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
82
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
83
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
84
+ signal = signal.view(1, channels, length)
85
+ return signal
86
+
87
+
88
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
89
+ b, channels, length = x.size()
90
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
91
+ return x + signal.to(dtype=x.dtype, device=x.device)
92
+
93
+
94
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
95
+ b, channels, length = x.size()
96
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
97
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
98
+
99
+
100
+ def subsequent_mask(length):
101
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
102
+ return mask
103
+
104
+
105
+ @torch.jit.script
106
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
107
+ n_channels_int = n_channels[0]
108
+ in_act = input_a + input_b
109
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
110
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
111
+ acts = t_act * s_act
112
+ return acts
113
+
114
+
115
+ def convert_pad_shape(pad_shape):
116
+ l = pad_shape[::-1]
117
+ pad_shape = [item for sublist in l for item in sublist]
118
+ return pad_shape
119
+
120
+
121
+ def shift_1d(x):
122
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
123
+ return x
124
+
125
+
126
+ def sequence_mask(length, max_length=None):
127
+ if max_length is None:
128
+ max_length = length.max()
129
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
130
+ return x.unsqueeze(0) < length.unsqueeze(1)
131
+
132
+
133
+ def generate_path(duration, mask):
134
+ """
135
+ duration: [b, 1, t_x]
136
+ mask: [b, 1, t_y, t_x]
137
+ """
138
+ device = duration.device
139
+
140
+ b, _, t_y, t_x = mask.shape
141
+ cum_duration = torch.cumsum(duration, -1)
142
+
143
+ cum_duration_flat = cum_duration.view(b * t_x)
144
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
145
+ path = path.view(b, t_x, t_y)
146
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
147
+ path = path.unsqueeze(1).transpose(2, 3) * mask
148
+ return path
149
+
150
+
151
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
152
+ if isinstance(parameters, torch.Tensor):
153
+ parameters = [parameters]
154
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
155
+ norm_type = float(norm_type)
156
+ if clip_value is not None:
157
+ clip_value = float(clip_value)
158
+
159
+ total_norm = 0
160
+ for p in parameters:
161
+ param_norm = p.grad.data.norm(norm_type)
162
+ total_norm += param_norm.item() ** norm_type
163
+ if clip_value is not None:
164
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
165
+ total_norm = total_norm ** (1.0 / norm_type)
166
+ return total_norm
lib/infer_pack/models.py ADDED
@@ -0,0 +1,1142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math, pdb, os
2
+ from time import time as ttime
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from lib.infer_pack import modules
7
+ from lib.infer_pack import attentions
8
+ from lib.infer_pack import commons
9
+ from lib.infer_pack.commons import init_weights, get_padding
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from lib.infer_pack.commons import init_weights
13
+ import numpy as np
14
+ from lib.infer_pack import commons
15
+
16
+
17
+ class TextEncoder256(nn.Module):
18
+ def __init__(
19
+ self,
20
+ out_channels,
21
+ hidden_channels,
22
+ filter_channels,
23
+ n_heads,
24
+ n_layers,
25
+ kernel_size,
26
+ p_dropout,
27
+ f0=True,
28
+ ):
29
+ super().__init__()
30
+ self.out_channels = out_channels
31
+ self.hidden_channels = hidden_channels
32
+ self.filter_channels = filter_channels
33
+ self.n_heads = n_heads
34
+ self.n_layers = n_layers
35
+ self.kernel_size = kernel_size
36
+ self.p_dropout = p_dropout
37
+ self.emb_phone = nn.Linear(256, hidden_channels)
38
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
+ if f0 == True:
40
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
+ self.encoder = attentions.Encoder(
42
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
+ )
44
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
+
46
+ def forward(self, phone, pitch, lengths):
47
+ if pitch == None:
48
+ x = self.emb_phone(phone)
49
+ else:
50
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
+ x = self.lrelu(x)
53
+ x = torch.transpose(x, 1, -1) # [b, h, t]
54
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
+ x.dtype
56
+ )
57
+ x = self.encoder(x * x_mask, x_mask)
58
+ stats = self.proj(x) * x_mask
59
+
60
+ m, logs = torch.split(stats, self.out_channels, dim=1)
61
+ return m, logs, x_mask
62
+
63
+
64
+ class TextEncoder768(nn.Module):
65
+ def __init__(
66
+ self,
67
+ out_channels,
68
+ hidden_channels,
69
+ filter_channels,
70
+ n_heads,
71
+ n_layers,
72
+ kernel_size,
73
+ p_dropout,
74
+ f0=True,
75
+ ):
76
+ super().__init__()
77
+ self.out_channels = out_channels
78
+ self.hidden_channels = hidden_channels
79
+ self.filter_channels = filter_channels
80
+ self.n_heads = n_heads
81
+ self.n_layers = n_layers
82
+ self.kernel_size = kernel_size
83
+ self.p_dropout = p_dropout
84
+ self.emb_phone = nn.Linear(768, hidden_channels)
85
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
86
+ if f0 == True:
87
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
88
+ self.encoder = attentions.Encoder(
89
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
90
+ )
91
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
92
+
93
+ def forward(self, phone, pitch, lengths):
94
+ if pitch == None:
95
+ x = self.emb_phone(phone)
96
+ else:
97
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
98
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
99
+ x = self.lrelu(x)
100
+ x = torch.transpose(x, 1, -1) # [b, h, t]
101
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
102
+ x.dtype
103
+ )
104
+ x = self.encoder(x * x_mask, x_mask)
105
+ stats = self.proj(x) * x_mask
106
+
107
+ m, logs = torch.split(stats, self.out_channels, dim=1)
108
+ return m, logs, x_mask
109
+
110
+
111
+ class ResidualCouplingBlock(nn.Module):
112
+ def __init__(
113
+ self,
114
+ channels,
115
+ hidden_channels,
116
+ kernel_size,
117
+ dilation_rate,
118
+ n_layers,
119
+ n_flows=4,
120
+ gin_channels=0,
121
+ ):
122
+ super().__init__()
123
+ self.channels = channels
124
+ self.hidden_channels = hidden_channels
125
+ self.kernel_size = kernel_size
126
+ self.dilation_rate = dilation_rate
127
+ self.n_layers = n_layers
128
+ self.n_flows = n_flows
129
+ self.gin_channels = gin_channels
130
+
131
+ self.flows = nn.ModuleList()
132
+ for i in range(n_flows):
133
+ self.flows.append(
134
+ modules.ResidualCouplingLayer(
135
+ channels,
136
+ hidden_channels,
137
+ kernel_size,
138
+ dilation_rate,
139
+ n_layers,
140
+ gin_channels=gin_channels,
141
+ mean_only=True,
142
+ )
143
+ )
144
+ self.flows.append(modules.Flip())
145
+
146
+ def forward(self, x, x_mask, g=None, reverse=False):
147
+ if not reverse:
148
+ for flow in self.flows:
149
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
150
+ else:
151
+ for flow in reversed(self.flows):
152
+ x = flow(x, x_mask, g=g, reverse=reverse)
153
+ return x
154
+
155
+ def remove_weight_norm(self):
156
+ for i in range(self.n_flows):
157
+ self.flows[i * 2].remove_weight_norm()
158
+
159
+
160
+ class PosteriorEncoder(nn.Module):
161
+ def __init__(
162
+ self,
163
+ in_channels,
164
+ out_channels,
165
+ hidden_channels,
166
+ kernel_size,
167
+ dilation_rate,
168
+ n_layers,
169
+ gin_channels=0,
170
+ ):
171
+ super().__init__()
172
+ self.in_channels = in_channels
173
+ self.out_channels = out_channels
174
+ self.hidden_channels = hidden_channels
175
+ self.kernel_size = kernel_size
176
+ self.dilation_rate = dilation_rate
177
+ self.n_layers = n_layers
178
+ self.gin_channels = gin_channels
179
+
180
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
181
+ self.enc = modules.WN(
182
+ hidden_channels,
183
+ kernel_size,
184
+ dilation_rate,
185
+ n_layers,
186
+ gin_channels=gin_channels,
187
+ )
188
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
189
+
190
+ def forward(self, x, x_lengths, g=None):
191
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
192
+ x.dtype
193
+ )
194
+ x = self.pre(x) * x_mask
195
+ x = self.enc(x, x_mask, g=g)
196
+ stats = self.proj(x) * x_mask
197
+ m, logs = torch.split(stats, self.out_channels, dim=1)
198
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
199
+ return z, m, logs, x_mask
200
+
201
+ def remove_weight_norm(self):
202
+ self.enc.remove_weight_norm()
203
+
204
+
205
+ class Generator(torch.nn.Module):
206
+ def __init__(
207
+ self,
208
+ initial_channel,
209
+ resblock,
210
+ resblock_kernel_sizes,
211
+ resblock_dilation_sizes,
212
+ upsample_rates,
213
+ upsample_initial_channel,
214
+ upsample_kernel_sizes,
215
+ gin_channels=0,
216
+ ):
217
+ super(Generator, self).__init__()
218
+ self.num_kernels = len(resblock_kernel_sizes)
219
+ self.num_upsamples = len(upsample_rates)
220
+ self.conv_pre = Conv1d(
221
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
222
+ )
223
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
224
+
225
+ self.ups = nn.ModuleList()
226
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
227
+ self.ups.append(
228
+ weight_norm(
229
+ ConvTranspose1d(
230
+ upsample_initial_channel // (2**i),
231
+ upsample_initial_channel // (2 ** (i + 1)),
232
+ k,
233
+ u,
234
+ padding=(k - u) // 2,
235
+ )
236
+ )
237
+ )
238
+
239
+ self.resblocks = nn.ModuleList()
240
+ for i in range(len(self.ups)):
241
+ ch = upsample_initial_channel // (2 ** (i + 1))
242
+ for j, (k, d) in enumerate(
243
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
244
+ ):
245
+ self.resblocks.append(resblock(ch, k, d))
246
+
247
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
248
+ self.ups.apply(init_weights)
249
+
250
+ if gin_channels != 0:
251
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
252
+
253
+ def forward(self, x, g=None):
254
+ x = self.conv_pre(x)
255
+ if g is not None:
256
+ x = x + self.cond(g)
257
+
258
+ for i in range(self.num_upsamples):
259
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
260
+ x = self.ups[i](x)
261
+ xs = None
262
+ for j in range(self.num_kernels):
263
+ if xs is None:
264
+ xs = self.resblocks[i * self.num_kernels + j](x)
265
+ else:
266
+ xs += self.resblocks[i * self.num_kernels + j](x)
267
+ x = xs / self.num_kernels
268
+ x = F.leaky_relu(x)
269
+ x = self.conv_post(x)
270
+ x = torch.tanh(x)
271
+
272
+ return x
273
+
274
+ def remove_weight_norm(self):
275
+ for l in self.ups:
276
+ remove_weight_norm(l)
277
+ for l in self.resblocks:
278
+ l.remove_weight_norm()
279
+
280
+
281
+ class SineGen(torch.nn.Module):
282
+ """Definition of sine generator
283
+ SineGen(samp_rate, harmonic_num = 0,
284
+ sine_amp = 0.1, noise_std = 0.003,
285
+ voiced_threshold = 0,
286
+ flag_for_pulse=False)
287
+ samp_rate: sampling rate in Hz
288
+ harmonic_num: number of harmonic overtones (default 0)
289
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
290
+ noise_std: std of Gaussian noise (default 0.003)
291
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
292
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
293
+ Note: when flag_for_pulse is True, the first time step of a voiced
294
+ segment is always sin(np.pi) or cos(0)
295
+ """
296
+
297
+ def __init__(
298
+ self,
299
+ samp_rate,
300
+ harmonic_num=0,
301
+ sine_amp=0.1,
302
+ noise_std=0.003,
303
+ voiced_threshold=0,
304
+ flag_for_pulse=False,
305
+ ):
306
+ super(SineGen, self).__init__()
307
+ self.sine_amp = sine_amp
308
+ self.noise_std = noise_std
309
+ self.harmonic_num = harmonic_num
310
+ self.dim = self.harmonic_num + 1
311
+ self.sampling_rate = samp_rate
312
+ self.voiced_threshold = voiced_threshold
313
+
314
+ def _f02uv(self, f0):
315
+ # generate uv signal
316
+ uv = torch.ones_like(f0)
317
+ uv = uv * (f0 > self.voiced_threshold)
318
+ return uv
319
+
320
+ def forward(self, f0, upp):
321
+ """sine_tensor, uv = forward(f0)
322
+ input F0: tensor(batchsize=1, length, dim=1)
323
+ f0 for unvoiced steps should be 0
324
+ output sine_tensor: tensor(batchsize=1, length, dim)
325
+ output uv: tensor(batchsize=1, length, 1)
326
+ """
327
+ with torch.no_grad():
328
+ f0 = f0[:, None].transpose(1, 2)
329
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
330
+ # fundamental component
331
+ f0_buf[:, :, 0] = f0[:, :, 0]
332
+ for idx in np.arange(self.harmonic_num):
333
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
334
+ idx + 2
335
+ ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
336
+ rad_values = (f0_buf / self.sampling_rate) % 1 ###%1 means that the product of n_har cannot be post-processed and optimized
337
+ rand_ini = torch.rand(
338
+ f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
339
+ )
340
+ rand_ini[:, 0] = 0
341
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
342
+ tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1 means that the following cumsum can no longer be optimized
343
+ tmp_over_one *= upp
344
+ tmp_over_one = F.interpolate(
345
+ tmp_over_one.transpose(2, 1),
346
+ scale_factor=upp,
347
+ mode="linear",
348
+ align_corners=True,
349
+ ).transpose(2, 1)
350
+ rad_values = F.interpolate(
351
+ rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
352
+ ).transpose(
353
+ 2, 1
354
+ ) #######
355
+ tmp_over_one %= 1
356
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
357
+ cumsum_shift = torch.zeros_like(rad_values)
358
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
359
+ sine_waves = torch.sin(
360
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
361
+ )
362
+ sine_waves = sine_waves * self.sine_amp
363
+ uv = self._f02uv(f0)
364
+ uv = F.interpolate(
365
+ uv.transpose(2, 1), scale_factor=upp, mode="nearest"
366
+ ).transpose(2, 1)
367
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
368
+ noise = noise_amp * torch.randn_like(sine_waves)
369
+ sine_waves = sine_waves * uv + noise
370
+ return sine_waves, uv, noise
371
+
372
+
373
+ class SourceModuleHnNSF(torch.nn.Module):
374
+ """SourceModule for hn-nsf
375
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
376
+ add_noise_std=0.003, voiced_threshod=0)
377
+ sampling_rate: sampling_rate in Hz
378
+ harmonic_num: number of harmonic above F0 (default: 0)
379
+ sine_amp: amplitude of sine source signal (default: 0.1)
380
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
381
+ note that amplitude of noise in unvoiced is decided
382
+ by sine_amp
383
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
384
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
385
+ F0_sampled (batchsize, length, 1)
386
+ Sine_source (batchsize, length, 1)
387
+ noise_source (batchsize, length 1)
388
+ uv (batchsize, length, 1)
389
+ """
390
+
391
+ def __init__(
392
+ self,
393
+ sampling_rate,
394
+ harmonic_num=0,
395
+ sine_amp=0.1,
396
+ add_noise_std=0.003,
397
+ voiced_threshod=0,
398
+ is_half=True,
399
+ ):
400
+ super(SourceModuleHnNSF, self).__init__()
401
+
402
+ self.sine_amp = sine_amp
403
+ self.noise_std = add_noise_std
404
+ self.is_half = is_half
405
+ # to produce sine waveforms
406
+ self.l_sin_gen = SineGen(
407
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
408
+ )
409
+
410
+ # to merge source harmonics into a single excitation
411
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
412
+ self.l_tanh = torch.nn.Tanh()
413
+
414
+ def forward(self, x, upp=None):
415
+ sine_wavs, uv, _ = self.l_sin_gen(x, upp)
416
+ if self.is_half:
417
+ sine_wavs = sine_wavs.half()
418
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
419
+ return sine_merge, None, None # noise, uv
420
+
421
+
422
+ class GeneratorNSF(torch.nn.Module):
423
+ def __init__(
424
+ self,
425
+ initial_channel,
426
+ resblock,
427
+ resblock_kernel_sizes,
428
+ resblock_dilation_sizes,
429
+ upsample_rates,
430
+ upsample_initial_channel,
431
+ upsample_kernel_sizes,
432
+ gin_channels,
433
+ sr,
434
+ is_half=False,
435
+ ):
436
+ super(GeneratorNSF, self).__init__()
437
+ self.num_kernels = len(resblock_kernel_sizes)
438
+ self.num_upsamples = len(upsample_rates)
439
+
440
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
441
+ self.m_source = SourceModuleHnNSF(
442
+ sampling_rate=sr, harmonic_num=0, is_half=is_half
443
+ )
444
+ self.noise_convs = nn.ModuleList()
445
+ self.conv_pre = Conv1d(
446
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
447
+ )
448
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
449
+
450
+ self.ups = nn.ModuleList()
451
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
452
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
453
+ self.ups.append(
454
+ weight_norm(
455
+ ConvTranspose1d(
456
+ upsample_initial_channel // (2**i),
457
+ upsample_initial_channel // (2 ** (i + 1)),
458
+ k,
459
+ u,
460
+ padding=(k - u) // 2,
461
+ )
462
+ )
463
+ )
464
+ if i + 1 < len(upsample_rates):
465
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
466
+ self.noise_convs.append(
467
+ Conv1d(
468
+ 1,
469
+ c_cur,
470
+ kernel_size=stride_f0 * 2,
471
+ stride=stride_f0,
472
+ padding=stride_f0 // 2,
473
+ )
474
+ )
475
+ else:
476
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
477
+
478
+ self.resblocks = nn.ModuleList()
479
+ for i in range(len(self.ups)):
480
+ ch = upsample_initial_channel // (2 ** (i + 1))
481
+ for j, (k, d) in enumerate(
482
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
483
+ ):
484
+ self.resblocks.append(resblock(ch, k, d))
485
+
486
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
487
+ self.ups.apply(init_weights)
488
+
489
+ if gin_channels != 0:
490
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
491
+
492
+ self.upp = np.prod(upsample_rates)
493
+
494
+ def forward(self, x, f0, g=None):
495
+ har_source, noi_source, uv = self.m_source(f0, self.upp)
496
+ har_source = har_source.transpose(1, 2)
497
+ x = self.conv_pre(x)
498
+ if g is not None:
499
+ x = x + self.cond(g)
500
+
501
+ for i in range(self.num_upsamples):
502
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
503
+ x = self.ups[i](x)
504
+ x_source = self.noise_convs[i](har_source)
505
+ x = x + x_source
506
+ xs = None
507
+ for j in range(self.num_kernels):
508
+ if xs is None:
509
+ xs = self.resblocks[i * self.num_kernels + j](x)
510
+ else:
511
+ xs += self.resblocks[i * self.num_kernels + j](x)
512
+ x = xs / self.num_kernels
513
+ x = F.leaky_relu(x)
514
+ x = self.conv_post(x)
515
+ x = torch.tanh(x)
516
+ return x
517
+
518
+ def remove_weight_norm(self):
519
+ for l in self.ups:
520
+ remove_weight_norm(l)
521
+ for l in self.resblocks:
522
+ l.remove_weight_norm()
523
+
524
+
525
+ sr2sr = {
526
+ "32k": 32000,
527
+ "40k": 40000,
528
+ "48k": 48000,
529
+ }
530
+
531
+
532
+ class SynthesizerTrnMs256NSFsid(nn.Module):
533
+ def __init__(
534
+ self,
535
+ spec_channels,
536
+ segment_size,
537
+ inter_channels,
538
+ hidden_channels,
539
+ filter_channels,
540
+ n_heads,
541
+ n_layers,
542
+ kernel_size,
543
+ p_dropout,
544
+ resblock,
545
+ resblock_kernel_sizes,
546
+ resblock_dilation_sizes,
547
+ upsample_rates,
548
+ upsample_initial_channel,
549
+ upsample_kernel_sizes,
550
+ spk_embed_dim,
551
+ gin_channels,
552
+ sr,
553
+ **kwargs
554
+ ):
555
+ super().__init__()
556
+ if type(sr) == type("strr"):
557
+ sr = sr2sr[sr]
558
+ self.spec_channels = spec_channels
559
+ self.inter_channels = inter_channels
560
+ self.hidden_channels = hidden_channels
561
+ self.filter_channels = filter_channels
562
+ self.n_heads = n_heads
563
+ self.n_layers = n_layers
564
+ self.kernel_size = kernel_size
565
+ self.p_dropout = p_dropout
566
+ self.resblock = resblock
567
+ self.resblock_kernel_sizes = resblock_kernel_sizes
568
+ self.resblock_dilation_sizes = resblock_dilation_sizes
569
+ self.upsample_rates = upsample_rates
570
+ self.upsample_initial_channel = upsample_initial_channel
571
+ self.upsample_kernel_sizes = upsample_kernel_sizes
572
+ self.segment_size = segment_size
573
+ self.gin_channels = gin_channels
574
+ # self.hop_length = hop_length#
575
+ self.spk_embed_dim = spk_embed_dim
576
+ self.enc_p = TextEncoder256(
577
+ inter_channels,
578
+ hidden_channels,
579
+ filter_channels,
580
+ n_heads,
581
+ n_layers,
582
+ kernel_size,
583
+ p_dropout,
584
+ )
585
+ self.dec = GeneratorNSF(
586
+ inter_channels,
587
+ resblock,
588
+ resblock_kernel_sizes,
589
+ resblock_dilation_sizes,
590
+ upsample_rates,
591
+ upsample_initial_channel,
592
+ upsample_kernel_sizes,
593
+ gin_channels=gin_channels,
594
+ sr=sr,
595
+ is_half=kwargs["is_half"],
596
+ )
597
+ self.enc_q = PosteriorEncoder(
598
+ spec_channels,
599
+ inter_channels,
600
+ hidden_channels,
601
+ 5,
602
+ 1,
603
+ 16,
604
+ gin_channels=gin_channels,
605
+ )
606
+ self.flow = ResidualCouplingBlock(
607
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
608
+ )
609
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
610
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
611
+
612
+ def remove_weight_norm(self):
613
+ self.dec.remove_weight_norm()
614
+ self.flow.remove_weight_norm()
615
+ self.enc_q.remove_weight_norm()
616
+
617
+ def forward(
618
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
619
+ ): # Here ds is id, [bs,1]
620
+ # print(1,pitch.shape)#[bs,t]
621
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1 is t, broadcast
622
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
623
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
624
+ z_p = self.flow(z, y_mask, g=g)
625
+ z_slice, ids_slice = commons.rand_slice_segments(
626
+ z, y_lengths, self.segment_size
627
+ )
628
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
629
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
630
+ # print(-2,pitchf.shape,z_slice.shape)
631
+ o = self.dec(z_slice, pitchf, g=g)
632
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
633
+
634
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
635
+ g = self.emb_g(sid).unsqueeze(-1)
636
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
637
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
638
+ if rate:
639
+ head = int(z_p.shape[2] * rate)
640
+ z_p = z_p[:, :, -head:]
641
+ x_mask = x_mask[:, :, -head:]
642
+ nsff0 = nsff0[:, -head:]
643
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
644
+ o = self.dec(z * x_mask, nsff0, g=g)
645
+ return o, x_mask, (z, z_p, m_p, logs_p)
646
+
647
+
648
+ class SynthesizerTrnMs768NSFsid(nn.Module):
649
+ def __init__(
650
+ self,
651
+ spec_channels,
652
+ segment_size,
653
+ inter_channels,
654
+ hidden_channels,
655
+ filter_channels,
656
+ n_heads,
657
+ n_layers,
658
+ kernel_size,
659
+ p_dropout,
660
+ resblock,
661
+ resblock_kernel_sizes,
662
+ resblock_dilation_sizes,
663
+ upsample_rates,
664
+ upsample_initial_channel,
665
+ upsample_kernel_sizes,
666
+ spk_embed_dim,
667
+ gin_channels,
668
+ sr,
669
+ **kwargs
670
+ ):
671
+ super().__init__()
672
+ if type(sr) == type("strr"):
673
+ sr = sr2sr[sr]
674
+ self.spec_channels = spec_channels
675
+ self.inter_channels = inter_channels
676
+ self.hidden_channels = hidden_channels
677
+ self.filter_channels = filter_channels
678
+ self.n_heads = n_heads
679
+ self.n_layers = n_layers
680
+ self.kernel_size = kernel_size
681
+ self.p_dropout = p_dropout
682
+ self.resblock = resblock
683
+ self.resblock_kernel_sizes = resblock_kernel_sizes
684
+ self.resblock_dilation_sizes = resblock_dilation_sizes
685
+ self.upsample_rates = upsample_rates
686
+ self.upsample_initial_channel = upsample_initial_channel
687
+ self.upsample_kernel_sizes = upsample_kernel_sizes
688
+ self.segment_size = segment_size
689
+ self.gin_channels = gin_channels
690
+ # self.hop_length = hop_length#
691
+ self.spk_embed_dim = spk_embed_dim
692
+ self.enc_p = TextEncoder768(
693
+ inter_channels,
694
+ hidden_channels,
695
+ filter_channels,
696
+ n_heads,
697
+ n_layers,
698
+ kernel_size,
699
+ p_dropout,
700
+ )
701
+ self.dec = GeneratorNSF(
702
+ inter_channels,
703
+ resblock,
704
+ resblock_kernel_sizes,
705
+ resblock_dilation_sizes,
706
+ upsample_rates,
707
+ upsample_initial_channel,
708
+ upsample_kernel_sizes,
709
+ gin_channels=gin_channels,
710
+ sr=sr,
711
+ is_half=kwargs["is_half"],
712
+ )
713
+ self.enc_q = PosteriorEncoder(
714
+ spec_channels,
715
+ inter_channels,
716
+ hidden_channels,
717
+ 5,
718
+ 1,
719
+ 16,
720
+ gin_channels=gin_channels,
721
+ )
722
+ self.flow = ResidualCouplingBlock(
723
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
724
+ )
725
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
726
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
727
+
728
+ def remove_weight_norm(self):
729
+ self.dec.remove_weight_norm()
730
+ self.flow.remove_weight_norm()
731
+ self.enc_q.remove_weight_norm()
732
+
733
+ def forward(
734
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
735
+ ): # Here ds is id,[bs,1]
736
+ # print(1,pitch.shape)#[bs,t]
737
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1 is t, broadcast
738
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
739
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
740
+ z_p = self.flow(z, y_mask, g=g)
741
+ z_slice, ids_slice = commons.rand_slice_segments(
742
+ z, y_lengths, self.segment_size
743
+ )
744
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
745
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
746
+ # print(-2,pitchf.shape,z_slice.shape)
747
+ o = self.dec(z_slice, pitchf, g=g)
748
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
749
+
750
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
751
+ g = self.emb_g(sid).unsqueeze(-1)
752
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
753
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
754
+ if rate:
755
+ head = int(z_p.shape[2] * rate)
756
+ z_p = z_p[:, :, -head:]
757
+ x_mask = x_mask[:, :, -head:]
758
+ nsff0 = nsff0[:, -head:]
759
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
760
+ o = self.dec(z * x_mask, nsff0, g=g)
761
+ return o, x_mask, (z, z_p, m_p, logs_p)
762
+
763
+
764
+ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
765
+ def __init__(
766
+ self,
767
+ spec_channels,
768
+ segment_size,
769
+ inter_channels,
770
+ hidden_channels,
771
+ filter_channels,
772
+ n_heads,
773
+ n_layers,
774
+ kernel_size,
775
+ p_dropout,
776
+ resblock,
777
+ resblock_kernel_sizes,
778
+ resblock_dilation_sizes,
779
+ upsample_rates,
780
+ upsample_initial_channel,
781
+ upsample_kernel_sizes,
782
+ spk_embed_dim,
783
+ gin_channels,
784
+ sr=None,
785
+ **kwargs
786
+ ):
787
+ super().__init__()
788
+ self.spec_channels = spec_channels
789
+ self.inter_channels = inter_channels
790
+ self.hidden_channels = hidden_channels
791
+ self.filter_channels = filter_channels
792
+ self.n_heads = n_heads
793
+ self.n_layers = n_layers
794
+ self.kernel_size = kernel_size
795
+ self.p_dropout = p_dropout
796
+ self.resblock = resblock
797
+ self.resblock_kernel_sizes = resblock_kernel_sizes
798
+ self.resblock_dilation_sizes = resblock_dilation_sizes
799
+ self.upsample_rates = upsample_rates
800
+ self.upsample_initial_channel = upsample_initial_channel
801
+ self.upsample_kernel_sizes = upsample_kernel_sizes
802
+ self.segment_size = segment_size
803
+ self.gin_channels = gin_channels
804
+ # self.hop_length = hop_length#
805
+ self.spk_embed_dim = spk_embed_dim
806
+ self.enc_p = TextEncoder256(
807
+ inter_channels,
808
+ hidden_channels,
809
+ filter_channels,
810
+ n_heads,
811
+ n_layers,
812
+ kernel_size,
813
+ p_dropout,
814
+ f0=False,
815
+ )
816
+ self.dec = Generator(
817
+ inter_channels,
818
+ resblock,
819
+ resblock_kernel_sizes,
820
+ resblock_dilation_sizes,
821
+ upsample_rates,
822
+ upsample_initial_channel,
823
+ upsample_kernel_sizes,
824
+ gin_channels=gin_channels,
825
+ )
826
+ self.enc_q = PosteriorEncoder(
827
+ spec_channels,
828
+ inter_channels,
829
+ hidden_channels,
830
+ 5,
831
+ 1,
832
+ 16,
833
+ gin_channels=gin_channels,
834
+ )
835
+ self.flow = ResidualCouplingBlock(
836
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
837
+ )
838
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
839
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
840
+
841
+ def remove_weight_norm(self):
842
+ self.dec.remove_weight_norm()
843
+ self.flow.remove_weight_norm()
844
+ self.enc_q.remove_weight_norm()
845
+
846
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # Here ds is id,[bs,1]
847
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1 is t, broadcast
848
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
849
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
850
+ z_p = self.flow(z, y_mask, g=g)
851
+ z_slice, ids_slice = commons.rand_slice_segments(
852
+ z, y_lengths, self.segment_size
853
+ )
854
+ o = self.dec(z_slice, g=g)
855
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
856
+
857
+ def infer(self, phone, phone_lengths, sid, rate=None):
858
+ g = self.emb_g(sid).unsqueeze(-1)
859
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
860
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
861
+ if rate:
862
+ head = int(z_p.shape[2] * rate)
863
+ z_p = z_p[:, :, -head:]
864
+ x_mask = x_mask[:, :, -head:]
865
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
866
+ o = self.dec(z * x_mask, g=g)
867
+ return o, x_mask, (z, z_p, m_p, logs_p)
868
+
869
+
870
+ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
871
+ def __init__(
872
+ self,
873
+ spec_channels,
874
+ segment_size,
875
+ inter_channels,
876
+ hidden_channels,
877
+ filter_channels,
878
+ n_heads,
879
+ n_layers,
880
+ kernel_size,
881
+ p_dropout,
882
+ resblock,
883
+ resblock_kernel_sizes,
884
+ resblock_dilation_sizes,
885
+ upsample_rates,
886
+ upsample_initial_channel,
887
+ upsample_kernel_sizes,
888
+ spk_embed_dim,
889
+ gin_channels,
890
+ sr=None,
891
+ **kwargs
892
+ ):
893
+ super().__init__()
894
+ self.spec_channels = spec_channels
895
+ self.inter_channels = inter_channels
896
+ self.hidden_channels = hidden_channels
897
+ self.filter_channels = filter_channels
898
+ self.n_heads = n_heads
899
+ self.n_layers = n_layers
900
+ self.kernel_size = kernel_size
901
+ self.p_dropout = p_dropout
902
+ self.resblock = resblock
903
+ self.resblock_kernel_sizes = resblock_kernel_sizes
904
+ self.resblock_dilation_sizes = resblock_dilation_sizes
905
+ self.upsample_rates = upsample_rates
906
+ self.upsample_initial_channel = upsample_initial_channel
907
+ self.upsample_kernel_sizes = upsample_kernel_sizes
908
+ self.segment_size = segment_size
909
+ self.gin_channels = gin_channels
910
+ # self.hop_length = hop_length#
911
+ self.spk_embed_dim = spk_embed_dim
912
+ self.enc_p = TextEncoder768(
913
+ inter_channels,
914
+ hidden_channels,
915
+ filter_channels,
916
+ n_heads,
917
+ n_layers,
918
+ kernel_size,
919
+ p_dropout,
920
+ f0=False,
921
+ )
922
+ self.dec = Generator(
923
+ inter_channels,
924
+ resblock,
925
+ resblock_kernel_sizes,
926
+ resblock_dilation_sizes,
927
+ upsample_rates,
928
+ upsample_initial_channel,
929
+ upsample_kernel_sizes,
930
+ gin_channels=gin_channels,
931
+ )
932
+ self.enc_q = PosteriorEncoder(
933
+ spec_channels,
934
+ inter_channels,
935
+ hidden_channels,
936
+ 5,
937
+ 1,
938
+ 16,
939
+ gin_channels=gin_channels,
940
+ )
941
+ self.flow = ResidualCouplingBlock(
942
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
943
+ )
944
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
945
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
946
+
947
+ def remove_weight_norm(self):
948
+ self.dec.remove_weight_norm()
949
+ self.flow.remove_weight_norm()
950
+ self.enc_q.remove_weight_norm()
951
+
952
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # Here ds is id,[bs,1]
953
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1 is t, broadcast
954
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
955
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
956
+ z_p = self.flow(z, y_mask, g=g)
957
+ z_slice, ids_slice = commons.rand_slice_segments(
958
+ z, y_lengths, self.segment_size
959
+ )
960
+ o = self.dec(z_slice, g=g)
961
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
962
+
963
+ def infer(self, phone, phone_lengths, sid, rate=None):
964
+ g = self.emb_g(sid).unsqueeze(-1)
965
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
966
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
967
+ if rate:
968
+ head = int(z_p.shape[2] * rate)
969
+ z_p = z_p[:, :, -head:]
970
+ x_mask = x_mask[:, :, -head:]
971
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
972
+ o = self.dec(z * x_mask, g=g)
973
+ return o, x_mask, (z, z_p, m_p, logs_p)
974
+
975
+
976
+ class MultiPeriodDiscriminator(torch.nn.Module):
977
+ def __init__(self, use_spectral_norm=False):
978
+ super(MultiPeriodDiscriminator, self).__init__()
979
+ periods = [2, 3, 5, 7, 11, 17]
980
+ # periods = [3, 5, 7, 11, 17, 23, 37]
981
+
982
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
983
+ discs = discs + [
984
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
985
+ ]
986
+ self.discriminators = nn.ModuleList(discs)
987
+
988
+ def forward(self, y, y_hat):
989
+ y_d_rs = [] #
990
+ y_d_gs = []
991
+ fmap_rs = []
992
+ fmap_gs = []
993
+ for i, d in enumerate(self.discriminators):
994
+ y_d_r, fmap_r = d(y)
995
+ y_d_g, fmap_g = d(y_hat)
996
+ # for j in range(len(fmap_r)):
997
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
998
+ y_d_rs.append(y_d_r)
999
+ y_d_gs.append(y_d_g)
1000
+ fmap_rs.append(fmap_r)
1001
+ fmap_gs.append(fmap_g)
1002
+
1003
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1004
+
1005
+
1006
+ class MultiPeriodDiscriminatorV2(torch.nn.Module):
1007
+ def __init__(self, use_spectral_norm=False):
1008
+ super(MultiPeriodDiscriminatorV2, self).__init__()
1009
+ # periods = [2, 3, 5, 7, 11, 17]
1010
+ periods = [2, 3, 5, 7, 11, 17, 23, 37]
1011
+
1012
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
1013
+ discs = discs + [
1014
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
1015
+ ]
1016
+ self.discriminators = nn.ModuleList(discs)
1017
+
1018
+ def forward(self, y, y_hat):
1019
+ y_d_rs = [] #
1020
+ y_d_gs = []
1021
+ fmap_rs = []
1022
+ fmap_gs = []
1023
+ for i, d in enumerate(self.discriminators):
1024
+ y_d_r, fmap_r = d(y)
1025
+ y_d_g, fmap_g = d(y_hat)
1026
+ # for j in range(len(fmap_r)):
1027
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
1028
+ y_d_rs.append(y_d_r)
1029
+ y_d_gs.append(y_d_g)
1030
+ fmap_rs.append(fmap_r)
1031
+ fmap_gs.append(fmap_g)
1032
+
1033
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1034
+
1035
+
1036
+ class DiscriminatorS(torch.nn.Module):
1037
+ def __init__(self, use_spectral_norm=False):
1038
+ super(DiscriminatorS, self).__init__()
1039
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
1040
+ self.convs = nn.ModuleList(
1041
+ [
1042
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
1043
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
1044
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
1045
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
1046
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
1047
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
1048
+ ]
1049
+ )
1050
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
1051
+
1052
+ def forward(self, x):
1053
+ fmap = []
1054
+
1055
+ for l in self.convs:
1056
+ x = l(x)
1057
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
1058
+ fmap.append(x)
1059
+ x = self.conv_post(x)
1060
+ fmap.append(x)
1061
+ x = torch.flatten(x, 1, -1)
1062
+
1063
+ return x, fmap
1064
+
1065
+
1066
+ class DiscriminatorP(torch.nn.Module):
1067
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
1068
+ super(DiscriminatorP, self).__init__()
1069
+ self.period = period
1070
+ self.use_spectral_norm = use_spectral_norm
1071
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
1072
+ self.convs = nn.ModuleList(
1073
+ [
1074
+ norm_f(
1075
+ Conv2d(
1076
+ 1,
1077
+ 32,
1078
+ (kernel_size, 1),
1079
+ (stride, 1),
1080
+ padding=(get_padding(kernel_size, 1), 0),
1081
+ )
1082
+ ),
1083
+ norm_f(
1084
+ Conv2d(
1085
+ 32,
1086
+ 128,
1087
+ (kernel_size, 1),
1088
+ (stride, 1),
1089
+ padding=(get_padding(kernel_size, 1), 0),
1090
+ )
1091
+ ),
1092
+ norm_f(
1093
+ Conv2d(
1094
+ 128,
1095
+ 512,
1096
+ (kernel_size, 1),
1097
+ (stride, 1),
1098
+ padding=(get_padding(kernel_size, 1), 0),
1099
+ )
1100
+ ),
1101
+ norm_f(
1102
+ Conv2d(
1103
+ 512,
1104
+ 1024,
1105
+ (kernel_size, 1),
1106
+ (stride, 1),
1107
+ padding=(get_padding(kernel_size, 1), 0),
1108
+ )
1109
+ ),
1110
+ norm_f(
1111
+ Conv2d(
1112
+ 1024,
1113
+ 1024,
1114
+ (kernel_size, 1),
1115
+ 1,
1116
+ padding=(get_padding(kernel_size, 1), 0),
1117
+ )
1118
+ ),
1119
+ ]
1120
+ )
1121
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
1122
+
1123
+ def forward(self, x):
1124
+ fmap = []
1125
+
1126
+ # 1d to 2d
1127
+ b, c, t = x.shape
1128
+ if t % self.period != 0: # pad first
1129
+ n_pad = self.period - (t % self.period)
1130
+ x = F.pad(x, (0, n_pad), "reflect")
1131
+ t = t + n_pad
1132
+ x = x.view(b, c, t // self.period, self.period)
1133
+
1134
+ for l in self.convs:
1135
+ x = l(x)
1136
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
1137
+ fmap.append(x)
1138
+ x = self.conv_post(x)
1139
+ fmap.append(x)
1140
+ x = torch.flatten(x, 1, -1)
1141
+
1142
+ return x, fmap
lib/infer_pack/modules.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm
11
+
12
+ from lib.infer_pack import commons
13
+ from lib.infer_pack.commons import init_weights, get_padding
14
+ from lib.infer_pack.transforms import piecewise_rational_quadratic_transform
15
+
16
+
17
+ LRELU_SLOPE = 0.1
18
+
19
+
20
+ class LayerNorm(nn.Module):
21
+ def __init__(self, channels, eps=1e-5):
22
+ super().__init__()
23
+ self.channels = channels
24
+ self.eps = eps
25
+
26
+ self.gamma = nn.Parameter(torch.ones(channels))
27
+ self.beta = nn.Parameter(torch.zeros(channels))
28
+
29
+ def forward(self, x):
30
+ x = x.transpose(1, -1)
31
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32
+ return x.transpose(1, -1)
33
+
34
+
35
+ class ConvReluNorm(nn.Module):
36
+ def __init__(
37
+ self,
38
+ in_channels,
39
+ hidden_channels,
40
+ out_channels,
41
+ kernel_size,
42
+ n_layers,
43
+ p_dropout,
44
+ ):
45
+ super().__init__()
46
+ self.in_channels = in_channels
47
+ self.hidden_channels = hidden_channels
48
+ self.out_channels = out_channels
49
+ self.kernel_size = kernel_size
50
+ self.n_layers = n_layers
51
+ self.p_dropout = p_dropout
52
+ assert n_layers > 1, "Number of layers should be larger than 0."
53
+
54
+ self.conv_layers = nn.ModuleList()
55
+ self.norm_layers = nn.ModuleList()
56
+ self.conv_layers.append(
57
+ nn.Conv1d(
58
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
59
+ )
60
+ )
61
+ self.norm_layers.append(LayerNorm(hidden_channels))
62
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63
+ for _ in range(n_layers - 1):
64
+ self.conv_layers.append(
65
+ nn.Conv1d(
66
+ hidden_channels,
67
+ hidden_channels,
68
+ kernel_size,
69
+ padding=kernel_size // 2,
70
+ )
71
+ )
72
+ self.norm_layers.append(LayerNorm(hidden_channels))
73
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74
+ self.proj.weight.data.zero_()
75
+ self.proj.bias.data.zero_()
76
+
77
+ def forward(self, x, x_mask):
78
+ x_org = x
79
+ for i in range(self.n_layers):
80
+ x = self.conv_layers[i](x * x_mask)
81
+ x = self.norm_layers[i](x)
82
+ x = self.relu_drop(x)
83
+ x = x_org + self.proj(x)
84
+ return x * x_mask
85
+
86
+
87
+ class DDSConv(nn.Module):
88
+ """
89
+ Dialted and Depth-Separable Convolution
90
+ """
91
+
92
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93
+ super().__init__()
94
+ self.channels = channels
95
+ self.kernel_size = kernel_size
96
+ self.n_layers = n_layers
97
+ self.p_dropout = p_dropout
98
+
99
+ self.drop = nn.Dropout(p_dropout)
100
+ self.convs_sep = nn.ModuleList()
101
+ self.convs_1x1 = nn.ModuleList()
102
+ self.norms_1 = nn.ModuleList()
103
+ self.norms_2 = nn.ModuleList()
104
+ for i in range(n_layers):
105
+ dilation = kernel_size**i
106
+ padding = (kernel_size * dilation - dilation) // 2
107
+ self.convs_sep.append(
108
+ nn.Conv1d(
109
+ channels,
110
+ channels,
111
+ kernel_size,
112
+ groups=channels,
113
+ dilation=dilation,
114
+ padding=padding,
115
+ )
116
+ )
117
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
118
+ self.norms_1.append(LayerNorm(channels))
119
+ self.norms_2.append(LayerNorm(channels))
120
+
121
+ def forward(self, x, x_mask, g=None):
122
+ if g is not None:
123
+ x = x + g
124
+ for i in range(self.n_layers):
125
+ y = self.convs_sep[i](x * x_mask)
126
+ y = self.norms_1[i](y)
127
+ y = F.gelu(y)
128
+ y = self.convs_1x1[i](y)
129
+ y = self.norms_2[i](y)
130
+ y = F.gelu(y)
131
+ y = self.drop(y)
132
+ x = x + y
133
+ return x * x_mask
134
+
135
+
136
+ class WN(torch.nn.Module):
137
+ def __init__(
138
+ self,
139
+ hidden_channels,
140
+ kernel_size,
141
+ dilation_rate,
142
+ n_layers,
143
+ gin_channels=0,
144
+ p_dropout=0,
145
+ ):
146
+ super(WN, self).__init__()
147
+ assert kernel_size % 2 == 1
148
+ self.hidden_channels = hidden_channels
149
+ self.kernel_size = (kernel_size,)
150
+ self.dilation_rate = dilation_rate
151
+ self.n_layers = n_layers
152
+ self.gin_channels = gin_channels
153
+ self.p_dropout = p_dropout
154
+
155
+ self.in_layers = torch.nn.ModuleList()
156
+ self.res_skip_layers = torch.nn.ModuleList()
157
+ self.drop = nn.Dropout(p_dropout)
158
+
159
+ if gin_channels != 0:
160
+ cond_layer = torch.nn.Conv1d(
161
+ gin_channels, 2 * hidden_channels * n_layers, 1
162
+ )
163
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
164
+
165
+ for i in range(n_layers):
166
+ dilation = dilation_rate**i
167
+ padding = int((kernel_size * dilation - dilation) / 2)
168
+ in_layer = torch.nn.Conv1d(
169
+ hidden_channels,
170
+ 2 * hidden_channels,
171
+ kernel_size,
172
+ dilation=dilation,
173
+ padding=padding,
174
+ )
175
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
176
+ self.in_layers.append(in_layer)
177
+
178
+ # last one is not necessary
179
+ if i < n_layers - 1:
180
+ res_skip_channels = 2 * hidden_channels
181
+ else:
182
+ res_skip_channels = hidden_channels
183
+
184
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
185
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
186
+ self.res_skip_layers.append(res_skip_layer)
187
+
188
+ def forward(self, x, x_mask, g=None, **kwargs):
189
+ output = torch.zeros_like(x)
190
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
191
+
192
+ if g is not None:
193
+ g = self.cond_layer(g)
194
+
195
+ for i in range(self.n_layers):
196
+ x_in = self.in_layers[i](x)
197
+ if g is not None:
198
+ cond_offset = i * 2 * self.hidden_channels
199
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
200
+ else:
201
+ g_l = torch.zeros_like(x_in)
202
+
203
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
204
+ acts = self.drop(acts)
205
+
206
+ res_skip_acts = self.res_skip_layers[i](acts)
207
+ if i < self.n_layers - 1:
208
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
209
+ x = (x + res_acts) * x_mask
210
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
211
+ else:
212
+ output = output + res_skip_acts
213
+ return output * x_mask
214
+
215
+ def remove_weight_norm(self):
216
+ if self.gin_channels != 0:
217
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
218
+ for l in self.in_layers:
219
+ torch.nn.utils.remove_weight_norm(l)
220
+ for l in self.res_skip_layers:
221
+ torch.nn.utils.remove_weight_norm(l)
222
+
223
+
224
+ class ResBlock1(torch.nn.Module):
225
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
226
+ super(ResBlock1, self).__init__()
227
+ self.convs1 = nn.ModuleList(
228
+ [
229
+ weight_norm(
230
+ Conv1d(
231
+ channels,
232
+ channels,
233
+ kernel_size,
234
+ 1,
235
+ dilation=dilation[0],
236
+ padding=get_padding(kernel_size, dilation[0]),
237
+ )
238
+ ),
239
+ weight_norm(
240
+ Conv1d(
241
+ channels,
242
+ channels,
243
+ kernel_size,
244
+ 1,
245
+ dilation=dilation[1],
246
+ padding=get_padding(kernel_size, dilation[1]),
247
+ )
248
+ ),
249
+ weight_norm(
250
+ Conv1d(
251
+ channels,
252
+ channels,
253
+ kernel_size,
254
+ 1,
255
+ dilation=dilation[2],
256
+ padding=get_padding(kernel_size, dilation[2]),
257
+ )
258
+ ),
259
+ ]
260
+ )
261
+ self.convs1.apply(init_weights)
262
+
263
+ self.convs2 = nn.ModuleList(
264
+ [
265
+ weight_norm(
266
+ Conv1d(
267
+ channels,
268
+ channels,
269
+ kernel_size,
270
+ 1,
271
+ dilation=1,
272
+ padding=get_padding(kernel_size, 1),
273
+ )
274
+ ),
275
+ weight_norm(
276
+ Conv1d(
277
+ channels,
278
+ channels,
279
+ kernel_size,
280
+ 1,
281
+ dilation=1,
282
+ padding=get_padding(kernel_size, 1),
283
+ )
284
+ ),
285
+ weight_norm(
286
+ Conv1d(
287
+ channels,
288
+ channels,
289
+ kernel_size,
290
+ 1,
291
+ dilation=1,
292
+ padding=get_padding(kernel_size, 1),
293
+ )
294
+ ),
295
+ ]
296
+ )
297
+ self.convs2.apply(init_weights)
298
+
299
+ def forward(self, x, x_mask=None):
300
+ for c1, c2 in zip(self.convs1, self.convs2):
301
+ xt = F.leaky_relu(x, LRELU_SLOPE)
302
+ if x_mask is not None:
303
+ xt = xt * x_mask
304
+ xt = c1(xt)
305
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
306
+ if x_mask is not None:
307
+ xt = xt * x_mask
308
+ xt = c2(xt)
309
+ x = xt + x
310
+ if x_mask is not None:
311
+ x = x * x_mask
312
+ return x
313
+
314
+ def remove_weight_norm(self):
315
+ for l in self.convs1:
316
+ remove_weight_norm(l)
317
+ for l in self.convs2:
318
+ remove_weight_norm(l)
319
+
320
+
321
+ class ResBlock2(torch.nn.Module):
322
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
323
+ super(ResBlock2, self).__init__()
324
+ self.convs = nn.ModuleList(
325
+ [
326
+ weight_norm(
327
+ Conv1d(
328
+ channels,
329
+ channels,
330
+ kernel_size,
331
+ 1,
332
+ dilation=dilation[0],
333
+ padding=get_padding(kernel_size, dilation[0]),
334
+ )
335
+ ),
336
+ weight_norm(
337
+ Conv1d(
338
+ channels,
339
+ channels,
340
+ kernel_size,
341
+ 1,
342
+ dilation=dilation[1],
343
+ padding=get_padding(kernel_size, dilation[1]),
344
+ )
345
+ ),
346
+ ]
347
+ )
348
+ self.convs.apply(init_weights)
349
+
350
+ def forward(self, x, x_mask=None):
351
+ for c in self.convs:
352
+ xt = F.leaky_relu(x, LRELU_SLOPE)
353
+ if x_mask is not None:
354
+ xt = xt * x_mask
355
+ xt = c(xt)
356
+ x = xt + x
357
+ if x_mask is not None:
358
+ x = x * x_mask
359
+ return x
360
+
361
+ def remove_weight_norm(self):
362
+ for l in self.convs:
363
+ remove_weight_norm(l)
364
+
365
+
366
+ class Log(nn.Module):
367
+ def forward(self, x, x_mask, reverse=False, **kwargs):
368
+ if not reverse:
369
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
370
+ logdet = torch.sum(-y, [1, 2])
371
+ return y, logdet
372
+ else:
373
+ x = torch.exp(x) * x_mask
374
+ return x
375
+
376
+
377
+ class Flip(nn.Module):
378
+ def forward(self, x, *args, reverse=False, **kwargs):
379
+ x = torch.flip(x, [1])
380
+ if not reverse:
381
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
382
+ return x, logdet
383
+ else:
384
+ return x
385
+
386
+
387
+ class ElementwiseAffine(nn.Module):
388
+ def __init__(self, channels):
389
+ super().__init__()
390
+ self.channels = channels
391
+ self.m = nn.Parameter(torch.zeros(channels, 1))
392
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
393
+
394
+ def forward(self, x, x_mask, reverse=False, **kwargs):
395
+ if not reverse:
396
+ y = self.m + torch.exp(self.logs) * x
397
+ y = y * x_mask
398
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
399
+ return y, logdet
400
+ else:
401
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
402
+ return x
403
+
404
+
405
+ class ResidualCouplingLayer(nn.Module):
406
+ def __init__(
407
+ self,
408
+ channels,
409
+ hidden_channels,
410
+ kernel_size,
411
+ dilation_rate,
412
+ n_layers,
413
+ p_dropout=0,
414
+ gin_channels=0,
415
+ mean_only=False,
416
+ ):
417
+ assert channels % 2 == 0, "channels should be divisible by 2"
418
+ super().__init__()
419
+ self.channels = channels
420
+ self.hidden_channels = hidden_channels
421
+ self.kernel_size = kernel_size
422
+ self.dilation_rate = dilation_rate
423
+ self.n_layers = n_layers
424
+ self.half_channels = channels // 2
425
+ self.mean_only = mean_only
426
+
427
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
428
+ self.enc = WN(
429
+ hidden_channels,
430
+ kernel_size,
431
+ dilation_rate,
432
+ n_layers,
433
+ p_dropout=p_dropout,
434
+ gin_channels=gin_channels,
435
+ )
436
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
437
+ self.post.weight.data.zero_()
438
+ self.post.bias.data.zero_()
439
+
440
+ def forward(self, x, x_mask, g=None, reverse=False):
441
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
442
+ h = self.pre(x0) * x_mask
443
+ h = self.enc(h, x_mask, g=g)
444
+ stats = self.post(h) * x_mask
445
+ if not self.mean_only:
446
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
447
+ else:
448
+ m = stats
449
+ logs = torch.zeros_like(m)
450
+
451
+ if not reverse:
452
+ x1 = m + x1 * torch.exp(logs) * x_mask
453
+ x = torch.cat([x0, x1], 1)
454
+ logdet = torch.sum(logs, [1, 2])
455
+ return x, logdet
456
+ else:
457
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
458
+ x = torch.cat([x0, x1], 1)
459
+ return x
460
+
461
+ def remove_weight_norm(self):
462
+ self.enc.remove_weight_norm()
463
+
464
+
465
+ class ConvFlow(nn.Module):
466
+ def __init__(
467
+ self,
468
+ in_channels,
469
+ filter_channels,
470
+ kernel_size,
471
+ n_layers,
472
+ num_bins=10,
473
+ tail_bound=5.0,
474
+ ):
475
+ super().__init__()
476
+ self.in_channels = in_channels
477
+ self.filter_channels = filter_channels
478
+ self.kernel_size = kernel_size
479
+ self.n_layers = n_layers
480
+ self.num_bins = num_bins
481
+ self.tail_bound = tail_bound
482
+ self.half_channels = in_channels // 2
483
+
484
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
485
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
486
+ self.proj = nn.Conv1d(
487
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
488
+ )
489
+ self.proj.weight.data.zero_()
490
+ self.proj.bias.data.zero_()
491
+
492
+ def forward(self, x, x_mask, g=None, reverse=False):
493
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
494
+ h = self.pre(x0)
495
+ h = self.convs(h, x_mask, g=g)
496
+ h = self.proj(h) * x_mask
497
+
498
+ b, c, t = x0.shape
499
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
500
+
501
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
502
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
503
+ self.filter_channels
504
+ )
505
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
506
+
507
+ x1, logabsdet = piecewise_rational_quadratic_transform(
508
+ x1,
509
+ unnormalized_widths,
510
+ unnormalized_heights,
511
+ unnormalized_derivatives,
512
+ inverse=reverse,
513
+ tails="linear",
514
+ tail_bound=self.tail_bound,
515
+ )
516
+
517
+ x = torch.cat([x0, x1], 1) * x_mask
518
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
519
+ if not reverse:
520
+ return x, logdet
521
+ else:
522
+ return x
lib/infer_pack/transforms.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(
13
+ inputs,
14
+ unnormalized_widths,
15
+ unnormalized_heights,
16
+ unnormalized_derivatives,
17
+ inverse=False,
18
+ tails=None,
19
+ tail_bound=1.0,
20
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
23
+ ):
24
+ if tails is None:
25
+ spline_fn = rational_quadratic_spline
26
+ spline_kwargs = {}
27
+ else:
28
+ spline_fn = unconstrained_rational_quadratic_spline
29
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
+
31
+ outputs, logabsdet = spline_fn(
32
+ inputs=inputs,
33
+ unnormalized_widths=unnormalized_widths,
34
+ unnormalized_heights=unnormalized_heights,
35
+ unnormalized_derivatives=unnormalized_derivatives,
36
+ inverse=inverse,
37
+ min_bin_width=min_bin_width,
38
+ min_bin_height=min_bin_height,
39
+ min_derivative=min_derivative,
40
+ **spline_kwargs
41
+ )
42
+ return outputs, logabsdet
43
+
44
+
45
+ def searchsorted(bin_locations, inputs, eps=1e-6):
46
+ bin_locations[..., -1] += eps
47
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
+
49
+
50
+ def unconstrained_rational_quadratic_spline(
51
+ inputs,
52
+ unnormalized_widths,
53
+ unnormalized_heights,
54
+ unnormalized_derivatives,
55
+ inverse=False,
56
+ tails="linear",
57
+ tail_bound=1.0,
58
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
61
+ ):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == "linear":
69
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ unnormalized_derivatives[..., -1] = constant
73
+
74
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
+ logabsdet[outside_interval_mask] = 0
76
+ else:
77
+ raise RuntimeError("{} tails are not implemented.".format(tails))
78
+
79
+ (
80
+ outputs[inside_interval_mask],
81
+ logabsdet[inside_interval_mask],
82
+ ) = rational_quadratic_spline(
83
+ inputs=inputs[inside_interval_mask],
84
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
+ inverse=inverse,
88
+ left=-tail_bound,
89
+ right=tail_bound,
90
+ bottom=-tail_bound,
91
+ top=tail_bound,
92
+ min_bin_width=min_bin_width,
93
+ min_bin_height=min_bin_height,
94
+ min_derivative=min_derivative,
95
+ )
96
+
97
+ return outputs, logabsdet
98
+
99
+
100
+ def rational_quadratic_spline(
101
+ inputs,
102
+ unnormalized_widths,
103
+ unnormalized_heights,
104
+ unnormalized_derivatives,
105
+ inverse=False,
106
+ left=0.0,
107
+ right=1.0,
108
+ bottom=0.0,
109
+ top=1.0,
110
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
113
+ ):
114
+ if torch.min(inputs) < left or torch.max(inputs) > right:
115
+ raise ValueError("Input to a transform is not within its domain")
116
+
117
+ num_bins = unnormalized_widths.shape[-1]
118
+
119
+ if min_bin_width * num_bins > 1.0:
120
+ raise ValueError("Minimal bin width too large for the number of bins")
121
+ if min_bin_height * num_bins > 1.0:
122
+ raise ValueError("Minimal bin height too large for the number of bins")
123
+
124
+ widths = F.softmax(unnormalized_widths, dim=-1)
125
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
+ cumwidths = torch.cumsum(widths, dim=-1)
127
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
+ cumwidths = (right - left) * cumwidths + left
129
+ cumwidths[..., 0] = left
130
+ cumwidths[..., -1] = right
131
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
+
133
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
+
135
+ heights = F.softmax(unnormalized_heights, dim=-1)
136
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
+ cumheights = torch.cumsum(heights, dim=-1)
138
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
+ cumheights = (top - bottom) * cumheights + bottom
140
+ cumheights[..., 0] = bottom
141
+ cumheights[..., -1] = top
142
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
143
+
144
+ if inverse:
145
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
146
+ else:
147
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
+
149
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
+
152
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
+ delta = heights / widths
154
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
155
+
156
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
+
159
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
160
+
161
+ if inverse:
162
+ a = (inputs - input_cumheights) * (
163
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
+ ) + input_heights * (input_delta - input_derivatives)
165
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
+ )
168
+ c = -input_delta * (inputs - input_cumheights)
169
+
170
+ discriminant = b.pow(2) - 4 * a * c
171
+ assert (discriminant >= 0).all()
172
+
173
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
174
+ outputs = root * input_bin_widths + input_cumwidths
175
+
176
+ theta_one_minus_theta = root * (1 - root)
177
+ denominator = input_delta + (
178
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179
+ * theta_one_minus_theta
180
+ )
181
+ derivative_numerator = input_delta.pow(2) * (
182
+ input_derivatives_plus_one * root.pow(2)
183
+ + 2 * input_delta * theta_one_minus_theta
184
+ + input_derivatives * (1 - root).pow(2)
185
+ )
186
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187
+
188
+ return outputs, -logabsdet
189
+ else:
190
+ theta = (inputs - input_cumwidths) / input_bin_widths
191
+ theta_one_minus_theta = theta * (1 - theta)
192
+
193
+ numerator = input_heights * (
194
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195
+ )
196
+ denominator = input_delta + (
197
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198
+ * theta_one_minus_theta
199
+ )
200
+ outputs = input_cumheights + numerator / denominator
201
+
202
+ derivative_numerator = input_delta.pow(2) * (
203
+ input_derivatives_plus_one * theta.pow(2)
204
+ + 2 * input_delta * theta_one_minus_theta
205
+ + input_derivatives * (1 - theta).pow(2)
206
+ )
207
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208
+
209
+ return outputs, logabsdet
lib/rmvpe.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, numpy as np
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+
7
+ class BiGRU(nn.Module):
8
+ def __init__(self, input_features, hidden_features, num_layers):
9
+ super(BiGRU, self).__init__()
10
+ self.gru = nn.GRU(
11
+ input_features,
12
+ hidden_features,
13
+ num_layers=num_layers,
14
+ batch_first=True,
15
+ bidirectional=True,
16
+ )
17
+
18
+ def forward(self, x):
19
+ return self.gru(x)[0]
20
+
21
+
22
+ class ConvBlockRes(nn.Module):
23
+ def __init__(self, in_channels, out_channels, momentum=0.01):
24
+ super(ConvBlockRes, self).__init__()
25
+ self.conv = nn.Sequential(
26
+ nn.Conv2d(
27
+ in_channels=in_channels,
28
+ out_channels=out_channels,
29
+ kernel_size=(3, 3),
30
+ stride=(1, 1),
31
+ padding=(1, 1),
32
+ bias=False,
33
+ ),
34
+ nn.BatchNorm2d(out_channels, momentum=momentum),
35
+ nn.ReLU(),
36
+ nn.Conv2d(
37
+ in_channels=out_channels,
38
+ out_channels=out_channels,
39
+ kernel_size=(3, 3),
40
+ stride=(1, 1),
41
+ padding=(1, 1),
42
+ bias=False,
43
+ ),
44
+ nn.BatchNorm2d(out_channels, momentum=momentum),
45
+ nn.ReLU(),
46
+ )
47
+ if in_channels != out_channels:
48
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
49
+ self.is_shortcut = True
50
+ else:
51
+ self.is_shortcut = False
52
+
53
+ def forward(self, x):
54
+ if self.is_shortcut:
55
+ return self.conv(x) + self.shortcut(x)
56
+ else:
57
+ return self.conv(x) + x
58
+
59
+
60
+ class Encoder(nn.Module):
61
+ def __init__(
62
+ self,
63
+ in_channels,
64
+ in_size,
65
+ n_encoders,
66
+ kernel_size,
67
+ n_blocks,
68
+ out_channels=16,
69
+ momentum=0.01,
70
+ ):
71
+ super(Encoder, self).__init__()
72
+ self.n_encoders = n_encoders
73
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
74
+ self.layers = nn.ModuleList()
75
+ self.latent_channels = []
76
+ for i in range(self.n_encoders):
77
+ self.layers.append(
78
+ ResEncoderBlock(
79
+ in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
80
+ )
81
+ )
82
+ self.latent_channels.append([out_channels, in_size])
83
+ in_channels = out_channels
84
+ out_channels *= 2
85
+ in_size //= 2
86
+ self.out_size = in_size
87
+ self.out_channel = out_channels
88
+
89
+ def forward(self, x):
90
+ concat_tensors = []
91
+ x = self.bn(x)
92
+ for i in range(self.n_encoders):
93
+ _, x = self.layers[i](x)
94
+ concat_tensors.append(_)
95
+ return x, concat_tensors
96
+
97
+
98
+ class ResEncoderBlock(nn.Module):
99
+ def __init__(
100
+ self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
101
+ ):
102
+ super(ResEncoderBlock, self).__init__()
103
+ self.n_blocks = n_blocks
104
+ self.conv = nn.ModuleList()
105
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
106
+ for i in range(n_blocks - 1):
107
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
108
+ self.kernel_size = kernel_size
109
+ if self.kernel_size is not None:
110
+ self.pool = nn.AvgPool2d(kernel_size=kernel_size)
111
+
112
+ def forward(self, x):
113
+ for i in range(self.n_blocks):
114
+ x = self.conv[i](x)
115
+ if self.kernel_size is not None:
116
+ return x, self.pool(x)
117
+ else:
118
+ return x
119
+
120
+
121
+ class Intermediate(nn.Module): #
122
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
123
+ super(Intermediate, self).__init__()
124
+ self.n_inters = n_inters
125
+ self.layers = nn.ModuleList()
126
+ self.layers.append(
127
+ ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
128
+ )
129
+ for i in range(self.n_inters - 1):
130
+ self.layers.append(
131
+ ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
132
+ )
133
+
134
+ def forward(self, x):
135
+ for i in range(self.n_inters):
136
+ x = self.layers[i](x)
137
+ return x
138
+
139
+
140
+ class ResDecoderBlock(nn.Module):
141
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
142
+ super(ResDecoderBlock, self).__init__()
143
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
144
+ self.n_blocks = n_blocks
145
+ self.conv1 = nn.Sequential(
146
+ nn.ConvTranspose2d(
147
+ in_channels=in_channels,
148
+ out_channels=out_channels,
149
+ kernel_size=(3, 3),
150
+ stride=stride,
151
+ padding=(1, 1),
152
+ output_padding=out_padding,
153
+ bias=False,
154
+ ),
155
+ nn.BatchNorm2d(out_channels, momentum=momentum),
156
+ nn.ReLU(),
157
+ )
158
+ self.conv2 = nn.ModuleList()
159
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
160
+ for i in range(n_blocks - 1):
161
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
162
+
163
+ def forward(self, x, concat_tensor):
164
+ x = self.conv1(x)
165
+ x = torch.cat((x, concat_tensor), dim=1)
166
+ for i in range(self.n_blocks):
167
+ x = self.conv2[i](x)
168
+ return x
169
+
170
+
171
+ class Decoder(nn.Module):
172
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
173
+ super(Decoder, self).__init__()
174
+ self.layers = nn.ModuleList()
175
+ self.n_decoders = n_decoders
176
+ for i in range(self.n_decoders):
177
+ out_channels = in_channels // 2
178
+ self.layers.append(
179
+ ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
180
+ )
181
+ in_channels = out_channels
182
+
183
+ def forward(self, x, concat_tensors):
184
+ for i in range(self.n_decoders):
185
+ x = self.layers[i](x, concat_tensors[-1 - i])
186
+ return x
187
+
188
+
189
+ class DeepUnet(nn.Module):
190
+ def __init__(
191
+ self,
192
+ kernel_size,
193
+ n_blocks,
194
+ en_de_layers=5,
195
+ inter_layers=4,
196
+ in_channels=1,
197
+ en_out_channels=16,
198
+ ):
199
+ super(DeepUnet, self).__init__()
200
+ self.encoder = Encoder(
201
+ in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
202
+ )
203
+ self.intermediate = Intermediate(
204
+ self.encoder.out_channel // 2,
205
+ self.encoder.out_channel,
206
+ inter_layers,
207
+ n_blocks,
208
+ )
209
+ self.decoder = Decoder(
210
+ self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
211
+ )
212
+
213
+ def forward(self, x):
214
+ x, concat_tensors = self.encoder(x)
215
+ x = self.intermediate(x)
216
+ x = self.decoder(x, concat_tensors)
217
+ return x
218
+
219
+
220
+ class E2E(nn.Module):
221
+ def __init__(
222
+ self,
223
+ n_blocks,
224
+ n_gru,
225
+ kernel_size,
226
+ en_de_layers=5,
227
+ inter_layers=4,
228
+ in_channels=1,
229
+ en_out_channels=16,
230
+ ):
231
+ super(E2E, self).__init__()
232
+ self.unet = DeepUnet(
233
+ kernel_size,
234
+ n_blocks,
235
+ en_de_layers,
236
+ inter_layers,
237
+ in_channels,
238
+ en_out_channels,
239
+ )
240
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
241
+ if n_gru:
242
+ self.fc = nn.Sequential(
243
+ BiGRU(3 * 128, 256, n_gru),
244
+ nn.Linear(512, 360),
245
+ nn.Dropout(0.25),
246
+ nn.Sigmoid(),
247
+ )
248
+ else:
249
+ self.fc = nn.Sequential(
250
+ nn.Linear(3 * nn.N_MELS, nn.N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
251
+ )
252
+
253
+ def forward(self, mel):
254
+ mel = mel.transpose(-1, -2).unsqueeze(1)
255
+ x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
256
+ x = self.fc(x)
257
+ return x
258
+
259
+
260
+ from librosa.filters import mel
261
+
262
+
263
+ class MelSpectrogram(torch.nn.Module):
264
+ def __init__(
265
+ self,
266
+ is_half,
267
+ n_mel_channels,
268
+ sampling_rate,
269
+ win_length,
270
+ hop_length,
271
+ n_fft=None,
272
+ mel_fmin=0,
273
+ mel_fmax=None,
274
+ clamp=1e-5,
275
+ ):
276
+ super().__init__()
277
+ n_fft = win_length if n_fft is None else n_fft
278
+ self.hann_window = {}
279
+ mel_basis = mel(
280
+ sr=sampling_rate,
281
+ n_fft=n_fft,
282
+ n_mels=n_mel_channels,
283
+ fmin=mel_fmin,
284
+ fmax=mel_fmax,
285
+ htk=True,
286
+ )
287
+ mel_basis = torch.from_numpy(mel_basis).float()
288
+ self.register_buffer("mel_basis", mel_basis)
289
+ self.n_fft = win_length if n_fft is None else n_fft
290
+ self.hop_length = hop_length
291
+ self.win_length = win_length
292
+ self.sampling_rate = sampling_rate
293
+ self.n_mel_channels = n_mel_channels
294
+ self.clamp = clamp
295
+ self.is_half = is_half
296
+
297
+ def forward(self, audio, keyshift=0, speed=1, center=True):
298
+ factor = 2 ** (keyshift / 12)
299
+ n_fft_new = int(np.round(self.n_fft * factor))
300
+ win_length_new = int(np.round(self.win_length * factor))
301
+ hop_length_new = int(np.round(self.hop_length * speed))
302
+ keyshift_key = str(keyshift) + "_" + str(audio.device)
303
+ if keyshift_key not in self.hann_window:
304
+ self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
305
+ audio.device
306
+ )
307
+ fft = torch.stft(
308
+ audio,
309
+ n_fft=n_fft_new,
310
+ hop_length=hop_length_new,
311
+ win_length=win_length_new,
312
+ window=self.hann_window[keyshift_key],
313
+ center=center,
314
+ return_complex=True,
315
+ )
316
+ magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
317
+ if keyshift != 0:
318
+ size = self.n_fft // 2 + 1
319
+ resize = magnitude.size(1)
320
+ if resize < size:
321
+ magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
322
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
323
+ mel_output = torch.matmul(self.mel_basis, magnitude)
324
+ if self.is_half == True:
325
+ mel_output = mel_output.half()
326
+ log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
327
+ return log_mel_spec
328
+
329
+
330
+ class RMVPE:
331
+ def __init__(self, model_path, is_half, device=None):
332
+ self.resample_kernel = {}
333
+ model = E2E(4, 1, (2, 2))
334
+ ckpt = torch.load(model_path, map_location="cpu")
335
+ model.load_state_dict(ckpt)
336
+ model.eval()
337
+ if is_half == True:
338
+ model = model.half()
339
+ self.model = model
340
+ self.resample_kernel = {}
341
+ self.is_half = is_half
342
+ if device is None:
343
+ device = "cuda" if torch.cuda.is_available() else "cpu"
344
+ self.device = device
345
+ self.mel_extractor = MelSpectrogram(
346
+ is_half, 128, 16000, 1024, 160, None, 30, 8000
347
+ ).to(device)
348
+ self.model = self.model.to(device)
349
+ cents_mapping = 20 * np.arange(360) + 1997.3794084376191
350
+ self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
351
+
352
+ def mel2hidden(self, mel):
353
+ with torch.no_grad():
354
+ n_frames = mel.shape[-1]
355
+ mel = F.pad(
356
+ mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
357
+ )
358
+ hidden = self.model(mel)
359
+ return hidden[:, :n_frames]
360
+
361
+ def decode(self, hidden, thred=0.03):
362
+ cents_pred = self.to_local_average_cents(hidden, thred=thred)
363
+ f0 = 10 * (2 ** (cents_pred / 1200))
364
+ f0[f0 == 10] = 0
365
+ # f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
366
+ return f0
367
+
368
+ def infer_from_audio(self, audio, thred=0.03):
369
+ audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
370
+ # torch.cuda.synchronize()
371
+ # t0=ttime()
372
+ mel = self.mel_extractor(audio, center=True)
373
+ # torch.cuda.synchronize()
374
+ # t1=ttime()
375
+ hidden = self.mel2hidden(mel)
376
+ # torch.cuda.synchronize()
377
+ # t2=ttime()
378
+ hidden = hidden.squeeze(0).cpu().numpy()
379
+ if self.is_half == True:
380
+ hidden = hidden.astype("float32")
381
+ f0 = self.decode(hidden, thred=thred)
382
+ # torch.cuda.synchronize()
383
+ # t3=ttime()
384
+ # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
385
+ return f0
386
+
387
+ def pitch_based_audio_inference(self, audio, thred=0.03, f0_min=50, f0_max=1100):
388
+ audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
389
+ mel = self.mel_extractor(audio, center=True)
390
+ hidden = self.mel2hidden(mel)
391
+ hidden = hidden.squeeze(0).cpu().numpy()
392
+ if self.is_half == True:
393
+ hidden = hidden.astype("float32")
394
+ f0 = self.decode(hidden, thred=thred)
395
+ f0[(f0 < f0_min) | (f0 > f0_max)] = 0
396
+ return f0
397
+
398
+ def to_local_average_cents(self, salience, thred=0.05):
399
+ # t0 = ttime()
400
+ center = np.argmax(salience, axis=1) # frame length#index
401
+ salience = np.pad(salience, ((0, 0), (4, 4))) # frame length,368
402
+ # t1 = ttime()
403
+ center += 4
404
+ todo_salience = []
405
+ todo_cents_mapping = []
406
+ starts = center - 4
407
+ ends = center + 5
408
+ for idx in range(salience.shape[0]):
409
+ todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
410
+ todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
411
+ # t2 = ttime()
412
+ todo_salience = np.array(todo_salience) # frame length,9
413
+ todo_cents_mapping = np.array(todo_cents_mapping) # frame length,9
414
+ product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
415
+ weight_sum = np.sum(todo_salience, 1) # frame length
416
+ devided = product_sum / weight_sum # frame length
417
+ # t3 = ttime()
418
+ maxx = np.max(salience, axis=1) # frame length
419
+ devided[maxx <= thred] = 0
420
+ # t4 = ttime()
421
+ # print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
422
+ return devided
mdx_models/data.json ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0ddfc0eb5792638ad5dc27850236c246": {
3
+ "compensate": 1.035,
4
+ "mdx_dim_f_set": 2048,
5
+ "mdx_dim_t_set": 8,
6
+ "mdx_n_fft_scale_set": 6144,
7
+ "primary_stem": "Vocals"
8
+ },
9
+ "26d308f91f3423a67dc69a6d12a8793d": {
10
+ "compensate": 1.035,
11
+ "mdx_dim_f_set": 2048,
12
+ "mdx_dim_t_set": 9,
13
+ "mdx_n_fft_scale_set": 8192,
14
+ "primary_stem": "Other"
15
+ },
16
+ "2cdd429caac38f0194b133884160f2c6": {
17
+ "compensate": 1.045,
18
+ "mdx_dim_f_set": 3072,
19
+ "mdx_dim_t_set": 8,
20
+ "mdx_n_fft_scale_set": 7680,
21
+ "primary_stem": "Instrumental"
22
+ },
23
+ "2f5501189a2f6db6349916fabe8c90de": {
24
+ "compensate": 1.035,
25
+ "mdx_dim_f_set": 2048,
26
+ "mdx_dim_t_set": 8,
27
+ "mdx_n_fft_scale_set": 6144,
28
+ "primary_stem": "Vocals"
29
+ },
30
+ "398580b6d5d973af3120df54cee6759d": {
31
+ "compensate": 1.75,
32
+ "mdx_dim_f_set": 3072,
33
+ "mdx_dim_t_set": 8,
34
+ "mdx_n_fft_scale_set": 7680,
35
+ "primary_stem": "Vocals"
36
+ },
37
+ "488b3e6f8bd3717d9d7c428476be2d75": {
38
+ "compensate": 1.035,
39
+ "mdx_dim_f_set": 3072,
40
+ "mdx_dim_t_set": 8,
41
+ "mdx_n_fft_scale_set": 7680,
42
+ "primary_stem": "Instrumental"
43
+ },
44
+ "4910e7827f335048bdac11fa967772f9": {
45
+ "compensate": 1.035,
46
+ "mdx_dim_f_set": 2048,
47
+ "mdx_dim_t_set": 7,
48
+ "mdx_n_fft_scale_set": 4096,
49
+ "primary_stem": "Drums"
50
+ },
51
+ "53c4baf4d12c3e6c3831bb8f5b532b93": {
52
+ "compensate": 1.043,
53
+ "mdx_dim_f_set": 3072,
54
+ "mdx_dim_t_set": 8,
55
+ "mdx_n_fft_scale_set": 7680,
56
+ "primary_stem": "Vocals"
57
+ },
58
+ "5d343409ef0df48c7d78cce9f0106781": {
59
+ "compensate": 1.075,
60
+ "mdx_dim_f_set": 3072,
61
+ "mdx_dim_t_set": 8,
62
+ "mdx_n_fft_scale_set": 7680,
63
+ "primary_stem": "Vocals"
64
+ },
65
+ "5f6483271e1efb9bfb59e4a3e6d4d098": {
66
+ "compensate": 1.035,
67
+ "mdx_dim_f_set": 2048,
68
+ "mdx_dim_t_set": 9,
69
+ "mdx_n_fft_scale_set": 6144,
70
+ "primary_stem": "Vocals"
71
+ },
72
+ "65ab5919372a128e4167f5e01a8fda85": {
73
+ "compensate": 1.035,
74
+ "mdx_dim_f_set": 2048,
75
+ "mdx_dim_t_set": 8,
76
+ "mdx_n_fft_scale_set": 8192,
77
+ "primary_stem": "Other"
78
+ },
79
+ "6703e39f36f18aa7855ee1047765621d": {
80
+ "compensate": 1.035,
81
+ "mdx_dim_f_set": 2048,
82
+ "mdx_dim_t_set": 9,
83
+ "mdx_n_fft_scale_set": 16384,
84
+ "primary_stem": "Bass"
85
+ },
86
+ "6b31de20e84392859a3d09d43f089515": {
87
+ "compensate": 1.035,
88
+ "mdx_dim_f_set": 2048,
89
+ "mdx_dim_t_set": 8,
90
+ "mdx_n_fft_scale_set": 6144,
91
+ "primary_stem": "Vocals"
92
+ },
93
+ "867595e9de46f6ab699008295df62798": {
94
+ "compensate": 1.03,
95
+ "mdx_dim_f_set": 3072,
96
+ "mdx_dim_t_set": 8,
97
+ "mdx_n_fft_scale_set": 7680,
98
+ "primary_stem": "Vocals"
99
+ },
100
+ "a3cd63058945e777505c01d2507daf37": {
101
+ "compensate": 1.03,
102
+ "mdx_dim_f_set": 2048,
103
+ "mdx_dim_t_set": 8,
104
+ "mdx_n_fft_scale_set": 6144,
105
+ "primary_stem": "Vocals"
106
+ },
107
+ "b33d9b3950b6cbf5fe90a32608924700": {
108
+ "compensate": 1.03,
109
+ "mdx_dim_f_set": 3072,
110
+ "mdx_dim_t_set": 8,
111
+ "mdx_n_fft_scale_set": 7680,
112
+ "primary_stem": "Vocals"
113
+ },
114
+ "c3b29bdce8c4fa17ec609e16220330ab": {
115
+ "compensate": 1.035,
116
+ "mdx_dim_f_set": 2048,
117
+ "mdx_dim_t_set": 8,
118
+ "mdx_n_fft_scale_set": 16384,
119
+ "primary_stem": "Bass"
120
+ },
121
+ "ceed671467c1f64ebdfac8a2490d0d52": {
122
+ "compensate": 1.035,
123
+ "mdx_dim_f_set": 3072,
124
+ "mdx_dim_t_set": 8,
125
+ "mdx_n_fft_scale_set": 7680,
126
+ "primary_stem": "Instrumental"
127
+ },
128
+ "d2a1376f310e4f7fa37fb9b5774eb701": {
129
+ "compensate": 1.035,
130
+ "mdx_dim_f_set": 3072,
131
+ "mdx_dim_t_set": 8,
132
+ "mdx_n_fft_scale_set": 7680,
133
+ "primary_stem": "Instrumental"
134
+ },
135
+ "d7bff498db9324db933d913388cba6be": {
136
+ "compensate": 1.035,
137
+ "mdx_dim_f_set": 2048,
138
+ "mdx_dim_t_set": 8,
139
+ "mdx_n_fft_scale_set": 6144,
140
+ "primary_stem": "Vocals"
141
+ },
142
+ "d94058f8c7f1fae4164868ae8ae66b20": {
143
+ "compensate": 1.035,
144
+ "mdx_dim_f_set": 2048,
145
+ "mdx_dim_t_set": 8,
146
+ "mdx_n_fft_scale_set": 6144,
147
+ "primary_stem": "Vocals"
148
+ },
149
+ "dc41ede5961d50f277eb846db17f5319": {
150
+ "compensate": 1.035,
151
+ "mdx_dim_f_set": 2048,
152
+ "mdx_dim_t_set": 9,
153
+ "mdx_n_fft_scale_set": 4096,
154
+ "primary_stem": "Drums"
155
+ },
156
+ "e5572e58abf111f80d8241d2e44e7fa4": {
157
+ "compensate": 1.028,
158
+ "mdx_dim_f_set": 3072,
159
+ "mdx_dim_t_set": 8,
160
+ "mdx_n_fft_scale_set": 7680,
161
+ "primary_stem": "Instrumental"
162
+ },
163
+ "e7324c873b1f615c35c1967f912db92a": {
164
+ "compensate": 1.03,
165
+ "mdx_dim_f_set": 3072,
166
+ "mdx_dim_t_set": 8,
167
+ "mdx_n_fft_scale_set": 7680,
168
+ "primary_stem": "Vocals"
169
+ },
170
+ "1c56ec0224f1d559c42fd6fd2a67b154": {
171
+ "compensate": 1.025,
172
+ "mdx_dim_f_set": 2048,
173
+ "mdx_dim_t_set": 8,
174
+ "mdx_n_fft_scale_set": 5120,
175
+ "primary_stem": "Instrumental"
176
+ },
177
+ "f2df6d6863d8f435436d8b561594ff49": {
178
+ "compensate": 1.035,
179
+ "mdx_dim_f_set": 3072,
180
+ "mdx_dim_t_set": 8,
181
+ "mdx_n_fft_scale_set": 7680,
182
+ "primary_stem": "Instrumental"
183
+ },
184
+ "b06327a00d5e5fbc7d96e1781bbdb596": {
185
+ "compensate": 1.035,
186
+ "mdx_dim_f_set": 3072,
187
+ "mdx_dim_t_set": 8,
188
+ "mdx_n_fft_scale_set": 6144,
189
+ "primary_stem": "Instrumental"
190
+ },
191
+ "94ff780b977d3ca07c7a343dab2e25dd": {
192
+ "compensate": 1.039,
193
+ "mdx_dim_f_set": 3072,
194
+ "mdx_dim_t_set": 8,
195
+ "mdx_n_fft_scale_set": 6144,
196
+ "primary_stem": "Instrumental"
197
+ },
198
+ "73492b58195c3b52d34590d5474452f6": {
199
+ "compensate": 1.043,
200
+ "mdx_dim_f_set": 3072,
201
+ "mdx_dim_t_set": 8,
202
+ "mdx_n_fft_scale_set": 7680,
203
+ "primary_stem": "Vocals"
204
+ },
205
+ "970b3f9492014d18fefeedfe4773cb42": {
206
+ "compensate": 1.009,
207
+ "mdx_dim_f_set": 3072,
208
+ "mdx_dim_t_set": 8,
209
+ "mdx_n_fft_scale_set": 7680,
210
+ "primary_stem": "Vocals"
211
+ },
212
+ "1d64a6d2c30f709b8c9b4ce1366d96ee": {
213
+ "compensate": 1.035,
214
+ "mdx_dim_f_set": 2048,
215
+ "mdx_dim_t_set": 8,
216
+ "mdx_n_fft_scale_set": 5120,
217
+ "primary_stem": "Instrumental"
218
+ },
219
+ "203f2a3955221b64df85a41af87cf8f0": {
220
+ "compensate": 1.035,
221
+ "mdx_dim_f_set": 3072,
222
+ "mdx_dim_t_set": 8,
223
+ "mdx_n_fft_scale_set": 6144,
224
+ "primary_stem": "Instrumental"
225
+ },
226
+ "291c2049608edb52648b96e27eb80e95": {
227
+ "compensate": 1.035,
228
+ "mdx_dim_f_set": 3072,
229
+ "mdx_dim_t_set": 8,
230
+ "mdx_n_fft_scale_set": 6144,
231
+ "primary_stem": "Instrumental"
232
+ },
233
+ "ead8d05dab12ec571d67549b3aab03fc": {
234
+ "compensate": 1.035,
235
+ "mdx_dim_f_set": 3072,
236
+ "mdx_dim_t_set": 8,
237
+ "mdx_n_fft_scale_set": 6144,
238
+ "primary_stem": "Instrumental"
239
+ },
240
+ "cc63408db3d80b4d85b0287d1d7c9632": {
241
+ "compensate": 1.033,
242
+ "mdx_dim_f_set": 3072,
243
+ "mdx_dim_t_set": 8,
244
+ "mdx_n_fft_scale_set": 6144,
245
+ "primary_stem": "Instrumental"
246
+ },
247
+ "cd5b2989ad863f116c855db1dfe24e39": {
248
+ "compensate": 1.035,
249
+ "mdx_dim_f_set": 3072,
250
+ "mdx_dim_t_set": 9,
251
+ "mdx_n_fft_scale_set": 6144,
252
+ "primary_stem": "Other"
253
+ },
254
+ "55657dd70583b0fedfba5f67df11d711": {
255
+ "compensate": 1.022,
256
+ "mdx_dim_f_set": 3072,
257
+ "mdx_dim_t_set": 8,
258
+ "mdx_n_fft_scale_set": 6144,
259
+ "primary_stem": "Instrumental"
260
+ },
261
+ "b6bccda408a436db8500083ef3491e8b": {
262
+ "compensate": 1.02,
263
+ "mdx_dim_f_set": 3072,
264
+ "mdx_dim_t_set": 8,
265
+ "mdx_n_fft_scale_set": 7680,
266
+ "primary_stem": "Instrumental"
267
+ },
268
+ "8a88db95c7fb5dbe6a095ff2ffb428b1": {
269
+ "compensate": 1.026,
270
+ "mdx_dim_f_set": 2048,
271
+ "mdx_dim_t_set": 8,
272
+ "mdx_n_fft_scale_set": 5120,
273
+ "primary_stem": "Instrumental"
274
+ },
275
+ "b78da4afc6512f98e4756f5977f5c6b9": {
276
+ "compensate": 1.021,
277
+ "mdx_dim_f_set": 3072,
278
+ "mdx_dim_t_set": 8,
279
+ "mdx_n_fft_scale_set": 7680,
280
+ "primary_stem": "Instrumental"
281
+ },
282
+ "77d07b2667ddf05b9e3175941b4454a0": {
283
+ "compensate": 1.021,
284
+ "mdx_dim_f_set": 3072,
285
+ "mdx_dim_t_set": 8,
286
+ "mdx_n_fft_scale_set": 7680,
287
+ "primary_stem": "Vocals"
288
+ },
289
+ "0f2a6bc5b49d87d64728ee40e23bceb1": {
290
+ "compensate": 1.019,
291
+ "mdx_dim_f_set": 2560,
292
+ "mdx_dim_t_set": 8,
293
+ "mdx_n_fft_scale_set": 5120,
294
+ "primary_stem": "Instrumental"
295
+ },
296
+ "b02be2d198d4968a121030cf8950b492": {
297
+ "compensate": 1.020,
298
+ "mdx_dim_f_set": 2560,
299
+ "mdx_dim_t_set": 8,
300
+ "mdx_n_fft_scale_set": 5120,
301
+ "primary_stem": "No Crowd"
302
+ },
303
+ "2154254ee89b2945b97a7efed6e88820": {
304
+ "config_yaml": "model_2_stem_061321.yaml"
305
+ },
306
+ "063aadd735d58150722926dcbf5852a9": {
307
+ "config_yaml": "model_2_stem_061321.yaml"
308
+ },
309
+ "fe96801369f6a148df2720f5ced88c19": {
310
+ "config_yaml": "model3.yaml"
311
+ },
312
+ "02e8b226f85fb566e5db894b9931c640": {
313
+ "config_yaml": "model2.yaml"
314
+ },
315
+ "e3de6d861635ab9c1d766149edd680d6": {
316
+ "config_yaml": "model1.yaml"
317
+ },
318
+ "3f2936c554ab73ce2e396d54636bd373": {
319
+ "config_yaml": "modelB.yaml"
320
+ },
321
+ "890d0f6f82d7574bca741a9e8bcb8168": {
322
+ "config_yaml": "modelB.yaml"
323
+ },
324
+ "63a3cb8c37c474681049be4ad1ba8815": {
325
+ "config_yaml": "modelB.yaml"
326
+ },
327
+ "a7fc5d719743c7fd6b61bd2b4d48b9f0": {
328
+ "config_yaml": "modelA.yaml"
329
+ },
330
+ "3567f3dee6e77bf366fcb1c7b8bc3745": {
331
+ "config_yaml": "modelA.yaml"
332
+ },
333
+ "a28f4d717bd0d34cd2ff7a3b0a3d065e": {
334
+ "config_yaml": "modelA.yaml"
335
+ },
336
+ "c9971a18da20911822593dc81caa8be9": {
337
+ "config_yaml": "sndfx.yaml"
338
+ },
339
+ "57d94d5ed705460d21c75a5ac829a605": {
340
+ "config_yaml": "sndfx.yaml"
341
+ },
342
+ "e7a25f8764f25a52c1b96c4946e66ba2": {
343
+ "config_yaml": "sndfx.yaml"
344
+ },
345
+ "104081d24e37217086ce5fde09147ee1": {
346
+ "config_yaml": "model_2_stem_061321.yaml"
347
+ },
348
+ "1e6165b601539f38d0a9330f3facffeb": {
349
+ "config_yaml": "model_2_stem_061321.yaml"
350
+ },
351
+ "fe0108464ce0d8271be5ab810891bd7c": {
352
+ "config_yaml": "model_2_stem_full_band.yaml"
353
+ }
354
+ }
packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ git-lfs
2
+ aria2 -y
3
+ ffmpeg
pre-requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu118
2
+ torch>=2.1.0+cu118
3
+ torchvision>=0.16.0+cu118
4
+ torchaudio>=2.1.0+cu118
5
+ yt-dlp
6
+ gradio==4.19.2
7
+ pydub==0.25.1
8
+ edge_tts==6.1.7
9
+ deep_translator==1.11.4
10
+ git+https://github.com/R3gm/pyannote-audio.git@3.1.1
11
+ git+https://github.com/R3gm/whisperX.git@cuda_11_8
12
+ nest_asyncio
13
+ gTTS
14
+ gradio_client==0.10.1
15
+ IPython
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ praat-parselmouth>=0.4.3
2
+ pyworld==0.3.2
3
+ faiss-cpu==1.7.3
4
+ torchcrepe==0.0.20
5
+ ffmpeg-python>=0.2.0
6
+ fairseq==0.12.2
7
+ gdown
8
+ rarfile
9
+ transformers
10
+ accelerate
11
+ optimum
12
+ sentencepiece
13
+ srt
14
+ git+https://github.com/R3gm/openvoice_package.git@lite
15
+ openai==1.14.3
16
+ tiktoken==0.6.0
17
+ # Documents
18
+ pypdf==4.2.0
19
+ python-docx
requirements_xtts.txt ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # core deps
2
+ numpy==1.23.5
3
+ cython>=0.29.30
4
+ scipy>=1.11.2
5
+ torch
6
+ torchaudio
7
+ soundfile
8
+ librosa
9
+ scikit-learn
10
+ numba
11
+ inflect>=5.6.0
12
+ tqdm>=4.64.1
13
+ anyascii>=0.3.0
14
+ pyyaml>=6.0
15
+ fsspec>=2023.6.0 # <= 2023.9.1 makes aux tests fail
16
+ aiohttp>=3.8.1
17
+ packaging>=23.1
18
+ # deps for examples
19
+ flask>=2.0.1
20
+ # deps for inference
21
+ pysbd>=0.3.4
22
+ # deps for notebooks
23
+ umap-learn>=0.5.1
24
+ pandas
25
+ # deps for training
26
+ matplotlib
27
+ # coqui stack
28
+ trainer>=0.0.32
29
+ # config management
30
+ coqpit>=0.0.16
31
+ # chinese g2p deps
32
+ jieba
33
+ pypinyin
34
+ # korean
35
+ hangul_romanize
36
+ # gruut+supported langs
37
+ gruut[de,es,fr]==2.2.3
38
+ # deps for korean
39
+ jamo
40
+ nltk
41
+ g2pkk>=0.1.1
42
+ # deps for bangla
43
+ bangla
44
+ bnnumerizer
45
+ bnunicodenormalizer
46
+ #deps for tortoise
47
+ einops>=0.6.0
48
+ transformers
49
+ #deps for bark
50
+ encodec>=0.1.1
51
+ # deps for XTTS
52
+ unidecode>=1.3.2
53
+ num2words
54
+ spacy[ja]>=3
55
+
56
+ # after this
57
+ # pip install -r requirements_xtts.txt
58
+ # pip install TTS==0.21.1 --no-deps
soni_translate/audio_segments.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydub import AudioSegment
2
+ from tqdm import tqdm
3
+ from .utils import run_command
4
+ from .logging_setup import logger
5
+ import numpy as np
6
+
7
+
8
+ class Mixer:
9
+ def __init__(self):
10
+ self.parts = []
11
+
12
+ def __len__(self):
13
+ parts = self._sync()
14
+ seg = parts[0][1]
15
+ frame_count = max(offset + seg.frame_count() for offset, seg in parts)
16
+ return int(1000.0 * frame_count / seg.frame_rate)
17
+
18
+ def overlay(self, sound, position=0):
19
+ self.parts.append((position, sound))
20
+ return self
21
+
22
+ def _sync(self):
23
+ positions, segs = zip(*self.parts)
24
+
25
+ frame_rate = segs[0].frame_rate
26
+ array_type = segs[0].array_type # noqa
27
+
28
+ offsets = [int(frame_rate * pos / 1000.0) for pos in positions]
29
+ segs = AudioSegment.empty()._sync(*segs)
30
+ return list(zip(offsets, segs))
31
+
32
+ def append(self, sound):
33
+ self.overlay(sound, position=len(self))
34
+
35
+ def to_audio_segment(self):
36
+ parts = self._sync()
37
+ seg = parts[0][1]
38
+ channels = seg.channels
39
+
40
+ frame_count = max(offset + seg.frame_count() for offset, seg in parts)
41
+ sample_count = int(frame_count * seg.channels)
42
+
43
+ output = np.zeros(sample_count, dtype="int32")
44
+ for offset, seg in parts:
45
+ sample_offset = offset * channels
46
+ samples = np.frombuffer(seg.get_array_of_samples(), dtype="int32")
47
+ samples = np.int16(samples/np.max(np.abs(samples)) * 32767)
48
+ start = sample_offset
49
+ end = start + len(samples)
50
+ output[start:end] += samples
51
+
52
+ return seg._spawn(
53
+ output, overrides={"sample_width": 4}).normalize(headroom=0.0)
54
+
55
+
56
+ def create_translated_audio(
57
+ result_diarize, audio_files, final_file, concat=False, avoid_overlap=False,
58
+ ):
59
+ total_duration = result_diarize["segments"][-1]["end"] # in seconds
60
+
61
+ if concat:
62
+ """
63
+ file .\audio\1.ogg
64
+ file .\audio\2.ogg
65
+ file .\audio\3.ogg
66
+ file .\audio\4.ogg
67
+ ...
68
+ """
69
+
70
+ # Write the file paths to list.txt
71
+ with open("list.txt", "w") as file:
72
+ for i, audio_file in enumerate(audio_files):
73
+ if i == len(audio_files) - 1: # Check if it's the last item
74
+ file.write(f"file {audio_file}")
75
+ else:
76
+ file.write(f"file {audio_file}\n")
77
+
78
+ # command = f"ffmpeg -f concat -safe 0 -i list.txt {final_file}"
79
+ command = (
80
+ f"ffmpeg -f concat -safe 0 -i list.txt -c:a pcm_s16le {final_file}"
81
+ )
82
+ run_command(command)
83
+
84
+ else:
85
+ # silent audio with total_duration
86
+ base_audio = AudioSegment.silent(
87
+ duration=int(total_duration * 1000), frame_rate=41000
88
+ )
89
+ combined_audio = Mixer()
90
+ combined_audio.overlay(base_audio)
91
+
92
+ logger.debug(
93
+ f"Audio duration: {total_duration // 60} "
94
+ f"minutes and {int(total_duration % 60)} seconds"
95
+ )
96
+
97
+ last_end_time = 0
98
+ previous_speaker = ""
99
+ for line, audio_file in tqdm(
100
+ zip(result_diarize["segments"], audio_files)
101
+ ):
102
+ start = float(line["start"])
103
+
104
+ # Overlay each audio at the corresponding time
105
+ try:
106
+ audio = AudioSegment.from_file(audio_file)
107
+ # audio_a = audio.speedup(playback_speed=1.5)
108
+
109
+ if avoid_overlap:
110
+ speaker = line["speaker"]
111
+ if (last_end_time - 0.500) > start:
112
+ overlap_time = last_end_time - start
113
+ if previous_speaker and previous_speaker != speaker:
114
+ start = (last_end_time - 0.500)
115
+ else:
116
+ start = (last_end_time - 0.200)
117
+ if overlap_time > 2.5:
118
+ start = start - 0.3
119
+ logger.info(
120
+ f"Avoid overlap for {str(audio_file)} "
121
+ f"with {str(start)}"
122
+ )
123
+
124
+ previous_speaker = speaker
125
+
126
+ duration_tts_seconds = len(audio) / 1000.0 # to sec
127
+ last_end_time = (start + duration_tts_seconds)
128
+
129
+ start_time = start * 1000 # to ms
130
+ combined_audio = combined_audio.overlay(
131
+ audio, position=start_time
132
+ )
133
+ except Exception as error:
134
+ logger.debug(str(error))
135
+ logger.error(f"Error audio file {audio_file}")
136
+
137
+ # combined audio as a file
138
+ combined_audio_data = combined_audio.to_audio_segment()
139
+ combined_audio_data.export(
140
+ final_file, format="wav"
141
+ ) # best than ogg, change if the audio is anomalous
soni_translate/language_configuration.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .logging_setup import logger
2
+
3
+ LANGUAGES_UNIDIRECTIONAL = {
4
+ "Aymara (ay)": "ay",
5
+ "Bambara (bm)": "bm",
6
+ "Cebuano (ceb)": "ceb",
7
+ "Chichewa (ny)": "ny",
8
+ "Divehi (dv)": "dv",
9
+ "Dogri (doi)": "doi",
10
+ "Ewe (ee)": "ee",
11
+ "Guarani (gn)": "gn",
12
+ "Iloko (ilo)": "ilo",
13
+ "Kinyarwanda (rw)": "rw",
14
+ "Krio (kri)": "kri",
15
+ "Kurdish (ku)": "ku",
16
+ "Kirghiz (ky)": "ky",
17
+ "Ganda (lg)": "lg",
18
+ "Maithili (mai)": "mai",
19
+ "Oriya (or)": "or",
20
+ "Oromo (om)": "om",
21
+ "Quechua (qu)": "qu",
22
+ "Samoan (sm)": "sm",
23
+ "Tigrinya (ti)": "ti",
24
+ "Tsonga (ts)": "ts",
25
+ "Akan (ak)": "ak",
26
+ "Uighur (ug)": "ug"
27
+ }
28
+
29
+ UNIDIRECTIONAL_L_LIST = LANGUAGES_UNIDIRECTIONAL.keys()
30
+
31
+ LANGUAGES = {
32
+ "Automatic detection": "Automatic detection",
33
+ "Arabic (ar)": "ar",
34
+ "Chinese - Simplified (zh-CN)": "zh",
35
+ "Czech (cs)": "cs",
36
+ "Danish (da)": "da",
37
+ "Dutch (nl)": "nl",
38
+ "English (en)": "en",
39
+ "Finnish (fi)": "fi",
40
+ "French (fr)": "fr",
41
+ "German (de)": "de",
42
+ "Greek (el)": "el",
43
+ "Hebrew (he)": "he",
44
+ "Hungarian (hu)": "hu",
45
+ "Italian (it)": "it",
46
+ "Japanese (ja)": "ja",
47
+ "Korean (ko)": "ko",
48
+ "Persian (fa)": "fa", # no aux gTTS
49
+ "Polish (pl)": "pl",
50
+ "Portuguese (pt)": "pt",
51
+ "Russian (ru)": "ru",
52
+ "Spanish (es)": "es",
53
+ "Turkish (tr)": "tr",
54
+ "Ukrainian (uk)": "uk",
55
+ "Urdu (ur)": "ur",
56
+ "Vietnamese (vi)": "vi",
57
+ "Hindi (hi)": "hi",
58
+ "Indonesian (id)": "id",
59
+ "Bengali (bn)": "bn",
60
+ "Telugu (te)": "te",
61
+ "Marathi (mr)": "mr",
62
+ "Tamil (ta)": "ta",
63
+ "Javanese (jw|jv)": "jw",
64
+ "Catalan (ca)": "ca",
65
+ "Nepali (ne)": "ne",
66
+ "Thai (th)": "th",
67
+ "Swedish (sv)": "sv",
68
+ "Amharic (am)": "am",
69
+ "Welsh (cy)": "cy", # no aux gTTS
70
+ "Estonian (et)": "et",
71
+ "Croatian (hr)": "hr",
72
+ "Icelandic (is)": "is",
73
+ "Georgian (ka)": "ka", # no aux gTTS
74
+ "Khmer (km)": "km",
75
+ "Slovak (sk)": "sk",
76
+ "Albanian (sq)": "sq",
77
+ "Serbian (sr)": "sr",
78
+ "Azerbaijani (az)": "az", # no aux gTTS
79
+ "Bulgarian (bg)": "bg",
80
+ "Galician (gl)": "gl", # no aux gTTS
81
+ "Gujarati (gu)": "gu",
82
+ "Kazakh (kk)": "kk", # no aux gTTS
83
+ "Kannada (kn)": "kn",
84
+ "Lithuanian (lt)": "lt", # no aux gTTS
85
+ "Latvian (lv)": "lv",
86
+ "Macedonian (mk)": "mk", # no aux gTTS # error get align model
87
+ "Malayalam (ml)": "ml",
88
+ "Malay (ms)": "ms", # error get align model
89
+ "Romanian (ro)": "ro",
90
+ "Sinhala (si)": "si",
91
+ "Sundanese (su)": "su",
92
+ "Swahili (sw)": "sw", # error aling
93
+ "Afrikaans (af)": "af",
94
+ "Bosnian (bs)": "bs",
95
+ "Latin (la)": "la",
96
+ "Myanmar Burmese (my)": "my",
97
+ "Norwegian (no|nb)": "no",
98
+ "Chinese - Traditional (zh-TW)": "zh-TW",
99
+ "Assamese (as)": "as",
100
+ "Basque (eu)": "eu",
101
+ "Hausa (ha)": "ha",
102
+ "Haitian Creole (ht)": "ht",
103
+ "Armenian (hy)": "hy",
104
+ "Lao (lo)": "lo",
105
+ "Malagasy (mg)": "mg",
106
+ "Mongolian (mn)": "mn",
107
+ "Maltese (mt)": "mt",
108
+ "Punjabi (pa)": "pa",
109
+ "Pashto (ps)": "ps",
110
+ "Slovenian (sl)": "sl",
111
+ "Shona (sn)": "sn",
112
+ "Somali (so)": "so",
113
+ "Tajik (tg)": "tg",
114
+ "Turkmen (tk)": "tk",
115
+ "Tatar (tt)": "tt",
116
+ "Uzbek (uz)": "uz",
117
+ "Yoruba (yo)": "yo",
118
+ **LANGUAGES_UNIDIRECTIONAL
119
+ }
120
+
121
+ BASE_L_LIST = LANGUAGES.keys()
122
+ LANGUAGES_LIST = [list(BASE_L_LIST)[0]] + sorted(list(BASE_L_LIST)[1:])
123
+ INVERTED_LANGUAGES = {value: key for key, value in LANGUAGES.items()}
124
+
125
+ EXTRA_ALIGN = {
126
+ "id": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
127
+ "bn": "arijitx/wav2vec2-large-xlsr-bengali",
128
+ "mr": "sumedh/wav2vec2-large-xlsr-marathi",
129
+ "ta": "Amrrs/wav2vec2-large-xlsr-53-tamil",
130
+ "jw": "cahya/wav2vec2-large-xlsr-javanese",
131
+ "ne": "shniranjan/wav2vec2-large-xlsr-300m-nepali",
132
+ "th": "sakares/wav2vec2-large-xlsr-thai-demo",
133
+ "sv": "KBLab/wav2vec2-large-voxrex-swedish",
134
+ "am": "agkphysics/wav2vec2-large-xlsr-53-amharic",
135
+ "cy": "Srulikbdd/Wav2Vec2-large-xlsr-welsh",
136
+ "et": "anton-l/wav2vec2-large-xlsr-53-estonian",
137
+ "hr": "classla/wav2vec2-xls-r-parlaspeech-hr",
138
+ "is": "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h",
139
+ "ka": "MehdiHosseiniMoghadam/wav2vec2-large-xlsr-53-Georgian",
140
+ "km": "vitouphy/wav2vec2-xls-r-300m-khmer",
141
+ "sk": "infinitejoy/wav2vec2-large-xls-r-300m-slovak",
142
+ "sq": "Alimzhan/wav2vec2-large-xls-r-300m-albanian-colab",
143
+ "sr": "dnikolic/wav2vec2-xlsr-530-serbian-colab",
144
+ "az": "nijatzeynalov/wav2vec2-large-mms-1b-azerbaijani-common_voice15.0",
145
+ "bg": "infinitejoy/wav2vec2-large-xls-r-300m-bulgarian",
146
+ "gl": "ifrz/wav2vec2-large-xlsr-galician",
147
+ "gu": "Harveenchadha/vakyansh-wav2vec2-gujarati-gnm-100",
148
+ "kk": "aismlv/wav2vec2-large-xlsr-kazakh",
149
+ "kn": "Harveenchadha/vakyansh-wav2vec2-kannada-knm-560",
150
+ "lt": "DeividasM/wav2vec2-large-xlsr-53-lithuanian",
151
+ "lv": "anton-l/wav2vec2-large-xlsr-53-latvian",
152
+ "mk": "", # Konstantin-Bogdanoski/wav2vec2-macedonian-base
153
+ "ml": "gvs/wav2vec2-large-xlsr-malayalam",
154
+ "ms": "", # Duy/wav2vec2_malay
155
+ "ro": "anton-l/wav2vec2-large-xlsr-53-romanian",
156
+ "si": "IAmNotAnanth/wav2vec2-large-xls-r-300m-sinhala",
157
+ "su": "cahya/wav2vec2-large-xlsr-sundanese",
158
+ "sw": "", # Lians/fine-tune-wav2vec2-large-swahili
159
+ "af": "", # ylacombe/wav2vec2-common_voice-af-demo
160
+ "bs": "",
161
+ "la": "",
162
+ "my": "",
163
+ "no": "NbAiLab/wav2vec2-xlsr-300m-norwegian",
164
+ "zh-TW": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
165
+ "as": "",
166
+ "eu": "", # cahya/wav2vec2-large-xlsr-basque # verify
167
+ "ha": "infinitejoy/wav2vec2-large-xls-r-300m-hausa",
168
+ "ht": "",
169
+ "hy": "infinitejoy/wav2vec2-large-xls-r-300m-armenian", # no (.)
170
+ "lo": "",
171
+ "mg": "",
172
+ "mn": "tugstugi/wav2vec2-large-xlsr-53-mongolian",
173
+ "mt": "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-maltese-64h",
174
+ "pa": "kingabzpro/wav2vec2-large-xlsr-53-punjabi",
175
+ "ps": "aamirhs/wav2vec2-large-xls-r-300m-pashto-colab",
176
+ "sl": "anton-l/wav2vec2-large-xlsr-53-slovenian",
177
+ "sn": "",
178
+ "so": "",
179
+ "tg": "",
180
+ "tk": "", # Ragav/wav2vec2-tk
181
+ "tt": "anton-l/wav2vec2-large-xlsr-53-tatar",
182
+ "uz": "", # Mekhriddin/wav2vec2-large-xls-r-300m-uzbek-colab
183
+ "yo": "ogbi/wav2vec2-large-mms-1b-yoruba-test",
184
+ }
185
+
186
+
187
+ def fix_code_language(translate_to, syntax="google"):
188
+ if syntax == "google":
189
+ # google-translator, gTTS
190
+ replace_lang_code = {"zh": "zh-CN", "he": "iw", "zh-cn": "zh-CN"}
191
+ elif syntax == "coqui":
192
+ # coqui-xtts
193
+ replace_lang_code = {"zh": "zh-cn", "zh-CN": "zh-cn", "zh-TW": "zh-cn"}
194
+
195
+ new_code_lang = replace_lang_code.get(translate_to, translate_to)
196
+ logger.debug(f"Fix code {translate_to} -> {new_code_lang}")
197
+ return new_code_lang
198
+
199
+
200
+ BARK_VOICES_LIST = {
201
+ "de_speaker_0-Male BARK": "v2/de_speaker_0",
202
+ "de_speaker_1-Male BARK": "v2/de_speaker_1",
203
+ "de_speaker_2-Male BARK": "v2/de_speaker_2",
204
+ "de_speaker_3-Female BARK": "v2/de_speaker_3",
205
+ "de_speaker_4-Male BARK": "v2/de_speaker_4",
206
+ "de_speaker_5-Male BARK": "v2/de_speaker_5",
207
+ "de_speaker_6-Male BARK": "v2/de_speaker_6",
208
+ "de_speaker_7-Male BARK": "v2/de_speaker_7",
209
+ "de_speaker_8-Female BARK": "v2/de_speaker_8",
210
+ "de_speaker_9-Male BARK": "v2/de_speaker_9",
211
+ "en_speaker_0-Male BARK": "v2/en_speaker_0",
212
+ "en_speaker_1-Male BARK": "v2/en_speaker_1",
213
+ "en_speaker_2-Male BARK": "v2/en_speaker_2",
214
+ "en_speaker_3-Male BARK": "v2/en_speaker_3",
215
+ "en_speaker_4-Male BARK": "v2/en_speaker_4",
216
+ "en_speaker_5-Male BARK": "v2/en_speaker_5",
217
+ "en_speaker_6-Male BARK": "v2/en_speaker_6",
218
+ "en_speaker_7-Male BARK": "v2/en_speaker_7",
219
+ "en_speaker_8-Male BARK": "v2/en_speaker_8",
220
+ "en_speaker_9-Female BARK": "v2/en_speaker_9",
221
+ "es_speaker_0-Male BARK": "v2/es_speaker_0",
222
+ "es_speaker_1-Male BARK": "v2/es_speaker_1",
223
+ "es_speaker_2-Male BARK": "v2/es_speaker_2",
224
+ "es_speaker_3-Male BARK": "v2/es_speaker_3",
225
+ "es_speaker_4-Male BARK": "v2/es_speaker_4",
226
+ "es_speaker_5-Male BARK": "v2/es_speaker_5",
227
+ "es_speaker_6-Male BARK": "v2/es_speaker_6",
228
+ "es_speaker_7-Male BARK": "v2/es_speaker_7",
229
+ "es_speaker_8-Female BARK": "v2/es_speaker_8",
230
+ "es_speaker_9-Female BARK": "v2/es_speaker_9",
231
+ "fr_speaker_0-Male BARK": "v2/fr_speaker_0",
232
+ "fr_speaker_1-Female BARK": "v2/fr_speaker_1",
233
+ "fr_speaker_2-Female BARK": "v2/fr_speaker_2",
234
+ "fr_speaker_3-Male BARK": "v2/fr_speaker_3",
235
+ "fr_speaker_4-Male BARK": "v2/fr_speaker_4",
236
+ "fr_speaker_5-Female BARK": "v2/fr_speaker_5",
237
+ "fr_speaker_6-Male BARK": "v2/fr_speaker_6",
238
+ "fr_speaker_7-Male BARK": "v2/fr_speaker_7",
239
+ "fr_speaker_8-Male BARK": "v2/fr_speaker_8",
240
+ "fr_speaker_9-Male BARK": "v2/fr_speaker_9",
241
+ "hi_speaker_0-Female BARK": "v2/hi_speaker_0",
242
+ "hi_speaker_1-Female BARK": "v2/hi_speaker_1",
243
+ "hi_speaker_2-Male BARK": "v2/hi_speaker_2",
244
+ "hi_speaker_3-Female BARK": "v2/hi_speaker_3",
245
+ "hi_speaker_4-Female BARK": "v2/hi_speaker_4",
246
+ "hi_speaker_5-Male BARK": "v2/hi_speaker_5",
247
+ "hi_speaker_6-Male BARK": "v2/hi_speaker_6",
248
+ "hi_speaker_7-Male BARK": "v2/hi_speaker_7",
249
+ "hi_speaker_8-Male BARK": "v2/hi_speaker_8",
250
+ "hi_speaker_9-Female BARK": "v2/hi_speaker_9",
251
+ "it_speaker_0-Male BARK": "v2/it_speaker_0",
252
+ "it_speaker_1-Male BARK": "v2/it_speaker_1",
253
+ "it_speaker_2-Female BARK": "v2/it_speaker_2",
254
+ "it_speaker_3-Male BARK": "v2/it_speaker_3",
255
+ "it_speaker_4-Male BARK": "v2/it_speaker_4",
256
+ "it_speaker_5-Male BARK": "v2/it_speaker_5",
257
+ "it_speaker_6-Male BARK": "v2/it_speaker_6",
258
+ "it_speaker_7-Female BARK": "v2/it_speaker_7",
259
+ "it_speaker_8-Male BARK": "v2/it_speaker_8",
260
+ "it_speaker_9-Female BARK": "v2/it_speaker_9",
261
+ "ja_speaker_0-Female BARK": "v2/ja_speaker_0",
262
+ "ja_speaker_1-Female BARK": "v2/ja_speaker_1",
263
+ "ja_speaker_2-Male BARK": "v2/ja_speaker_2",
264
+ "ja_speaker_3-Female BARK": "v2/ja_speaker_3",
265
+ "ja_speaker_4-Female BARK": "v2/ja_speaker_4",
266
+ "ja_speaker_5-Female BARK": "v2/ja_speaker_5",
267
+ "ja_speaker_6-Male BARK": "v2/ja_speaker_6",
268
+ "ja_speaker_7-Female BARK": "v2/ja_speaker_7",
269
+ "ja_speaker_8-Female BARK": "v2/ja_speaker_8",
270
+ "ja_speaker_9-Female BARK": "v2/ja_speaker_9",
271
+ "ko_speaker_0-Female BARK": "v2/ko_speaker_0",
272
+ "ko_speaker_1-Male BARK": "v2/ko_speaker_1",
273
+ "ko_speaker_2-Male BARK": "v2/ko_speaker_2",
274
+ "ko_speaker_3-Male BARK": "v2/ko_speaker_3",
275
+ "ko_speaker_4-Male BARK": "v2/ko_speaker_4",
276
+ "ko_speaker_5-Male BARK": "v2/ko_speaker_5",
277
+ "ko_speaker_6-Male BARK": "v2/ko_speaker_6",
278
+ "ko_speaker_7-Male BARK": "v2/ko_speaker_7",
279
+ "ko_speaker_8-Male BARK": "v2/ko_speaker_8",
280
+ "ko_speaker_9-Male BARK": "v2/ko_speaker_9",
281
+ "pl_speaker_0-Male BARK": "v2/pl_speaker_0",
282
+ "pl_speaker_1-Male BARK": "v2/pl_speaker_1",
283
+ "pl_speaker_2-Male BARK": "v2/pl_speaker_2",
284
+ "pl_speaker_3-Male BARK": "v2/pl_speaker_3",
285
+ "pl_speaker_4-Female BARK": "v2/pl_speaker_4",
286
+ "pl_speaker_5-Male BARK": "v2/pl_speaker_5",
287
+ "pl_speaker_6-Female BARK": "v2/pl_speaker_6",
288
+ "pl_speaker_7-Male BARK": "v2/pl_speaker_7",
289
+ "pl_speaker_8-Male BARK": "v2/pl_speaker_8",
290
+ "pl_speaker_9-Female BARK": "v2/pl_speaker_9",
291
+ "pt_speaker_0-Male BARK": "v2/pt_speaker_0",
292
+ "pt_speaker_1-Male BARK": "v2/pt_speaker_1",
293
+ "pt_speaker_2-Male BARK": "v2/pt_speaker_2",
294
+ "pt_speaker_3-Male BARK": "v2/pt_speaker_3",
295
+ "pt_speaker_4-Male BARK": "v2/pt_speaker_4",
296
+ "pt_speaker_5-Male BARK": "v2/pt_speaker_5",
297
+ "pt_speaker_6-Male BARK": "v2/pt_speaker_6",
298
+ "pt_speaker_7-Male BARK": "v2/pt_speaker_7",
299
+ "pt_speaker_8-Male BARK": "v2/pt_speaker_8",
300
+ "pt_speaker_9-Male BARK": "v2/pt_speaker_9",
301
+ "ru_speaker_0-Male BARK": "v2/ru_speaker_0",
302
+ "ru_speaker_1-Male BARK": "v2/ru_speaker_1",
303
+ "ru_speaker_2-Male BARK": "v2/ru_speaker_2",
304
+ "ru_speaker_3-Male BARK": "v2/ru_speaker_3",
305
+ "ru_speaker_4-Male BARK": "v2/ru_speaker_4",
306
+ "ru_speaker_5-Female BARK": "v2/ru_speaker_5",
307
+ "ru_speaker_6-Female BARK": "v2/ru_speaker_6",
308
+ "ru_speaker_7-Male BARK": "v2/ru_speaker_7",
309
+ "ru_speaker_8-Male BARK": "v2/ru_speaker_8",
310
+ "ru_speaker_9-Female BARK": "v2/ru_speaker_9",
311
+ "tr_speaker_0-Male BARK": "v2/tr_speaker_0",
312
+ "tr_speaker_1-Male BARK": "v2/tr_speaker_1",
313
+ "tr_speaker_2-Male BARK": "v2/tr_speaker_2",
314
+ "tr_speaker_3-Male BARK": "v2/tr_speaker_3",
315
+ "tr_speaker_4-Female BARK": "v2/tr_speaker_4",
316
+ "tr_speaker_5-Female BARK": "v2/tr_speaker_5",
317
+ "tr_speaker_6-Male BARK": "v2/tr_speaker_6",
318
+ "tr_speaker_7-Male BARK": "v2/tr_speaker_7",
319
+ "tr_speaker_8-Male BARK": "v2/tr_speaker_8",
320
+ "tr_speaker_9-Male BARK": "v2/tr_speaker_9",
321
+ "zh_speaker_0-Male BARK": "v2/zh_speaker_0",
322
+ "zh_speaker_1-Male BARK": "v2/zh_speaker_1",
323
+ "zh_speaker_2-Male BARK": "v2/zh_speaker_2",
324
+ "zh_speaker_3-Male BARK": "v2/zh_speaker_3",
325
+ "zh_speaker_4-Female BARK": "v2/zh_speaker_4",
326
+ "zh_speaker_5-Male BARK": "v2/zh_speaker_5",
327
+ "zh_speaker_6-Female BARK": "v2/zh_speaker_6",
328
+ "zh_speaker_7-Female BARK": "v2/zh_speaker_7",
329
+ "zh_speaker_8-Male BARK": "v2/zh_speaker_8",
330
+ "zh_speaker_9-Female BARK": "v2/zh_speaker_9",
331
+ }
332
+
333
+ VITS_VOICES_LIST = {
334
+ "ar-facebook-mms VITS": "facebook/mms-tts-ara",
335
+ # 'zh-facebook-mms VITS': 'facebook/mms-tts-cmn',
336
+ "zh_Hakka-facebook-mms VITS": "facebook/mms-tts-hak",
337
+ "zh_MinNan-facebook-mms VITS": "facebook/mms-tts-nan",
338
+ # 'cs-facebook-mms VITS': 'facebook/mms-tts-ces',
339
+ # 'da-facebook-mms VITS': 'facebook/mms-tts-dan',
340
+ "nl-facebook-mms VITS": "facebook/mms-tts-nld",
341
+ "en-facebook-mms VITS": "facebook/mms-tts-eng",
342
+ "fi-facebook-mms VITS": "facebook/mms-tts-fin",
343
+ "fr-facebook-mms VITS": "facebook/mms-tts-fra",
344
+ "de-facebook-mms VITS": "facebook/mms-tts-deu",
345
+ "el-facebook-mms VITS": "facebook/mms-tts-ell",
346
+ "el_Ancient-facebook-mms VITS": "facebook/mms-tts-grc",
347
+ "he-facebook-mms VITS": "facebook/mms-tts-heb",
348
+ "hu-facebook-mms VITS": "facebook/mms-tts-hun",
349
+ # 'it-facebook-mms VITS': 'facebook/mms-tts-ita',
350
+ # 'ja-facebook-mms VITS': 'facebook/mms-tts-jpn',
351
+ "ko-facebook-mms VITS": "facebook/mms-tts-kor",
352
+ "fa-facebook-mms VITS": "facebook/mms-tts-fas",
353
+ "pl-facebook-mms VITS": "facebook/mms-tts-pol",
354
+ "pt-facebook-mms VITS": "facebook/mms-tts-por",
355
+ "ru-facebook-mms VITS": "facebook/mms-tts-rus",
356
+ "es-facebook-mms VITS": "facebook/mms-tts-spa",
357
+ "tr-facebook-mms VITS": "facebook/mms-tts-tur",
358
+ "uk-facebook-mms VITS": "facebook/mms-tts-ukr",
359
+ "ur_arabic-facebook-mms VITS": "facebook/mms-tts-urd-script_arabic",
360
+ "ur_devanagari-facebook-mms VITS": "facebook/mms-tts-urd-script_devanagari",
361
+ "ur_latin-facebook-mms VITS": "facebook/mms-tts-urd-script_latin",
362
+ "vi-facebook-mms VITS": "facebook/mms-tts-vie",
363
+ "hi-facebook-mms VITS": "facebook/mms-tts-hin",
364
+ "hi_Fiji-facebook-mms VITS": "facebook/mms-tts-hif",
365
+ "id-facebook-mms VITS": "facebook/mms-tts-ind",
366
+ "bn-facebook-mms VITS": "facebook/mms-tts-ben",
367
+ "te-facebook-mms VITS": "facebook/mms-tts-tel",
368
+ "mr-facebook-mms VITS": "facebook/mms-tts-mar",
369
+ "ta-facebook-mms VITS": "facebook/mms-tts-tam",
370
+ "jw-facebook-mms VITS": "facebook/mms-tts-jav",
371
+ "jw_Suriname-facebook-mms VITS": "facebook/mms-tts-jvn",
372
+ "ca-facebook-mms VITS": "facebook/mms-tts-cat",
373
+ "ne-facebook-mms VITS": "facebook/mms-tts-nep",
374
+ "th-facebook-mms VITS": "facebook/mms-tts-tha",
375
+ "th_Northern-facebook-mms VITS": "facebook/mms-tts-nod",
376
+ "sv-facebook-mms VITS": "facebook/mms-tts-swe",
377
+ "am-facebook-mms VITS": "facebook/mms-tts-amh",
378
+ "cy-facebook-mms VITS": "facebook/mms-tts-cym",
379
+ # "et-facebook-mms VITS": "facebook/mms-tts-est",
380
+ # "ht-facebook-mms VITS": "facebook/mms-tts-hrv",
381
+ "is-facebook-mms VITS": "facebook/mms-tts-isl",
382
+ "km-facebook-mms VITS": "facebook/mms-tts-khm",
383
+ "km_Northern-facebook-mms VITS": "facebook/mms-tts-kxm",
384
+ # "sk-facebook-mms VITS": "facebook/mms-tts-slk",
385
+ "sq_Northern-facebook-mms VITS": "facebook/mms-tts-sqi",
386
+ "az_South-facebook-mms VITS": "facebook/mms-tts-azb",
387
+ "az_North_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-azj-script_cyrillic",
388
+ "az_North_script_latin-facebook-mms VITS": "facebook/mms-tts-azj-script_latin",
389
+ "bg-facebook-mms VITS": "facebook/mms-tts-bul",
390
+ # "gl-facebook-mms VITS": "facebook/mms-tts-glg",
391
+ "gu-facebook-mms VITS": "facebook/mms-tts-guj",
392
+ "kk-facebook-mms VITS": "facebook/mms-tts-kaz",
393
+ "kn-facebook-mms VITS": "facebook/mms-tts-kan",
394
+ # "lt-facebook-mms VITS": "facebook/mms-tts-lit",
395
+ "lv-facebook-mms VITS": "facebook/mms-tts-lav",
396
+ # "mk-facebook-mms VITS": "facebook/mms-tts-mkd",
397
+ "ml-facebook-mms VITS": "facebook/mms-tts-mal",
398
+ "ms-facebook-mms VITS": "facebook/mms-tts-zlm",
399
+ "ms_Central-facebook-mms VITS": "facebook/mms-tts-pse",
400
+ "ms_Manado-facebook-mms VITS": "facebook/mms-tts-xmm",
401
+ "ro-facebook-mms VITS": "facebook/mms-tts-ron",
402
+ # "si-facebook-mms VITS": "facebook/mms-tts-sin",
403
+ "sw-facebook-mms VITS": "facebook/mms-tts-swh",
404
+ # "af-facebook-mms VITS": "facebook/mms-tts-afr",
405
+ # "bs-facebook-mms VITS": "facebook/mms-tts-bos",
406
+ "la-facebook-mms VITS": "facebook/mms-tts-lat",
407
+ "my-facebook-mms VITS": "facebook/mms-tts-mya",
408
+ # "no_Bokmål-facebook-mms VITS": "thomasht86/mms-tts-nob", # verify
409
+ "as-facebook-mms VITS": "facebook/mms-tts-asm",
410
+ "as_Nagamese-facebook-mms VITS": "facebook/mms-tts-nag",
411
+ "eu-facebook-mms VITS": "facebook/mms-tts-eus",
412
+ "ha-facebook-mms VITS": "facebook/mms-tts-hau",
413
+ "ht-facebook-mms VITS": "facebook/mms-tts-hat",
414
+ "hy_Western-facebook-mms VITS": "facebook/mms-tts-hyw",
415
+ "lo-facebook-mms VITS": "facebook/mms-tts-lao",
416
+ "mg-facebook-mms VITS": "facebook/mms-tts-mlg",
417
+ "mn-facebook-mms VITS": "facebook/mms-tts-mon",
418
+ # "mt-facebook-mms VITS": "facebook/mms-tts-mlt",
419
+ "pa_Eastern-facebook-mms VITS": "facebook/mms-tts-pan",
420
+ # "pa_Western-facebook-mms VITS": "facebook/mms-tts-pnb",
421
+ # "ps-facebook-mms VITS": "facebook/mms-tts-pus",
422
+ # "sl-facebook-mms VITS": "facebook/mms-tts-slv",
423
+ "sn-facebook-mms VITS": "facebook/mms-tts-sna",
424
+ "so-facebook-mms VITS": "facebook/mms-tts-son",
425
+ "tg-facebook-mms VITS": "facebook/mms-tts-tgk",
426
+ "tk_script_arabic-facebook-mms VITS": "facebook/mms-tts-tuk-script_arabic",
427
+ "tk_script_latin-facebook-mms VITS": "facebook/mms-tts-tuk-script_latin",
428
+ "tt-facebook-mms VITS": "facebook/mms-tts-tat",
429
+ "tt_Crimean-facebook-mms VITS": "facebook/mms-tts-crh",
430
+ "uz_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-uzb-script_cyrillic",
431
+ "yo-facebook-mms VITS": "facebook/mms-tts-yor",
432
+ "ay-facebook-mms VITS": "facebook/mms-tts-ayr",
433
+ "bm-facebook-mms VITS": "facebook/mms-tts-bam",
434
+ "ceb-facebook-mms VITS": "facebook/mms-tts-ceb",
435
+ "ny-facebook-mms VITS": "facebook/mms-tts-nya",
436
+ "dv-facebook-mms VITS": "facebook/mms-tts-div",
437
+ "doi-facebook-mms VITS": "facebook/mms-tts-dgo",
438
+ "ee-facebook-mms VITS": "facebook/mms-tts-ewe",
439
+ "gn-facebook-mms VITS": "facebook/mms-tts-grn",
440
+ "ilo-facebook-mms VITS": "facebook/mms-tts-ilo",
441
+ "rw-facebook-mms VITS": "facebook/mms-tts-kin",
442
+ "kri-facebook-mms VITS": "facebook/mms-tts-kri",
443
+ "ku_script_arabic-facebook-mms VITS": "facebook/mms-tts-kmr-script_arabic",
444
+ "ku_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-kmr-script_cyrillic",
445
+ "ku_script_latin-facebook-mms VITS": "facebook/mms-tts-kmr-script_latin",
446
+ "ckb-facebook-mms VITS": "razhan/mms-tts-ckb", # Verify w
447
+ "ky-facebook-mms VITS": "facebook/mms-tts-kir",
448
+ "lg-facebook-mms VITS": "facebook/mms-tts-lug",
449
+ "mai-facebook-mms VITS": "facebook/mms-tts-mai",
450
+ "or-facebook-mms VITS": "facebook/mms-tts-ory",
451
+ "om-facebook-mms VITS": "facebook/mms-tts-orm",
452
+ "qu_Huallaga-facebook-mms VITS": "facebook/mms-tts-qub",
453
+ "qu_Lambayeque-facebook-mms VITS": "facebook/mms-tts-quf",
454
+ "qu_South_Bolivian-facebook-mms VITS": "facebook/mms-tts-quh",
455
+ "qu_North_Bolivian-facebook-mms VITS": "facebook/mms-tts-qul",
456
+ "qu_Tena_Lowland-facebook-mms VITS": "facebook/mms-tts-quw",
457
+ "qu_Ayacucho-facebook-mms VITS": "facebook/mms-tts-quy",
458
+ "qu_Cusco-facebook-mms VITS": "facebook/mms-tts-quz",
459
+ "qu_Cajamarca-facebook-mms VITS": "facebook/mms-tts-qvc",
460
+ "qu_Eastern_Apurímac-facebook-mms VITS": "facebook/mms-tts-qve",
461
+ "qu_Huamalíes_Dos_de_Mayo_Huánuco-facebook-mms VITS": "facebook/mms-tts-qvh",
462
+ "qu_Margos_Yarowilca_Lauricocha-facebook-mms VITS": "facebook/mms-tts-qvm",
463
+ "qu_North_Junín-facebook-mms VITS": "facebook/mms-tts-qvn",
464
+ "qu_Napo-facebook-mms VITS": "facebook/mms-tts-qvo",
465
+ "qu_San_Martín-facebook-mms VITS": "facebook/mms-tts-qvs",
466
+ "qu_Huaylla_Wanca-facebook-mms VITS": "facebook/mms-tts-qvw",
467
+ "qu_Northern_Pastaza-facebook-mms VITS": "facebook/mms-tts-qvz",
468
+ "qu_Huaylas_Ancash-facebook-mms VITS": "facebook/mms-tts-qwh",
469
+ "qu_Panao-facebook-mms VITS": "facebook/mms-tts-qxh",
470
+ "qu_Salasaca_Highland-facebook-mms VITS": "facebook/mms-tts-qxl",
471
+ "qu_Northern_Conchucos_Ancash-facebook-mms VITS": "facebook/mms-tts-qxn",
472
+ "qu_Southern_Conchucos-facebook-mms VITS": "facebook/mms-tts-qxo",
473
+ "qu_Cañar_Highland-facebook-mms VITS": "facebook/mms-tts-qxr",
474
+ "sm-facebook-mms VITS": "facebook/mms-tts-smo",
475
+ "ti-facebook-mms VITS": "facebook/mms-tts-tir",
476
+ "ts-facebook-mms VITS": "facebook/mms-tts-tso",
477
+ "ak-facebook-mms VITS": "facebook/mms-tts-aka",
478
+ "ug_script_arabic-facebook-mms VITS": "facebook/mms-tts-uig-script_arabic",
479
+ "ug_script_cyrillic-facebook-mms VITS": "facebook/mms-tts-uig-script_cyrillic",
480
+ }
481
+
482
+ OPENAI_TTS_CODES = [
483
+ "af", "ar", "hy", "az", "be", "bs", "bg", "ca", "zh", "hr", "cs", "da",
484
+ "nl", "en", "et", "fi", "fr", "gl", "de", "el", "he", "hi", "hu", "is",
485
+ "id", "it", "ja", "kn", "kk", "ko", "lv", "lt", "mk", "ms", "mr", "mi",
486
+ "ne", "no", "fa", "pl", "pt", "ro", "ru", "sr", "sk", "sl", "es", "sw",
487
+ "sv", "tl", "ta", "th", "tr", "uk", "ur", "vi", "cy", "zh-TW"
488
+ ]
489
+
490
+ OPENAI_TTS_MODELS = [
491
+ ">alloy OpenAI-TTS",
492
+ ">echo OpenAI-TTS",
493
+ ">fable OpenAI-TTS",
494
+ ">onyx OpenAI-TTS",
495
+ ">nova OpenAI-TTS",
496
+ ">shimmer OpenAI-TTS",
497
+ ">alloy HD OpenAI-TTS",
498
+ ">echo HD OpenAI-TTS",
499
+ ">fable HD OpenAI-TTS",
500
+ ">onyx HD OpenAI-TTS",
501
+ ">nova HD OpenAI-TTS",
502
+ ">shimmer HD OpenAI-TTS"
503
+ ]
504
+
505
+ LANGUAGE_CODE_IN_THREE_LETTERS = {
506
+ "Automatic detection": "aut",
507
+ "ar": "ara",
508
+ "zh": "chi",
509
+ "cs": "cze",
510
+ "da": "dan",
511
+ "nl": "dut",
512
+ "en": "eng",
513
+ "fi": "fin",
514
+ "fr": "fre",
515
+ "de": "ger",
516
+ "el": "gre",
517
+ "he": "heb",
518
+ "hu": "hun",
519
+ "it": "ita",
520
+ "ja": "jpn",
521
+ "ko": "kor",
522
+ "fa": "per",
523
+ "pl": "pol",
524
+ "pt": "por",
525
+ "ru": "rus",
526
+ "es": "spa",
527
+ "tr": "tur",
528
+ "uk": "ukr",
529
+ "ur": "urd",
530
+ "vi": "vie",
531
+ "hi": "hin",
532
+ "id": "ind",
533
+ "bn": "ben",
534
+ "te": "tel",
535
+ "mr": "mar",
536
+ "ta": "tam",
537
+ "jw": "jav",
538
+ "ca": "cat",
539
+ "ne": "nep",
540
+ "th": "tha",
541
+ "sv": "swe",
542
+ "am": "amh",
543
+ "cy": "cym",
544
+ "et": "est",
545
+ "hr": "hrv",
546
+ "is": "isl",
547
+ "km": "khm",
548
+ "sk": "slk",
549
+ "sq": "sqi",
550
+ "sr": "srp",
551
+ }
soni_translate/languages_gui.py ADDED
The diff for this file is too large to render. See raw diff
 
soni_translate/logging_setup.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import sys
3
+ import warnings
4
+ import os
5
+
6
+
7
+ def configure_logging_libs(debug=False):
8
+ warnings.filterwarnings(
9
+ action="ignore", category=UserWarning, module="pyannote"
10
+ )
11
+ modules = [
12
+ "numba", "httpx", "markdown_it", "speechbrain", "fairseq", "pyannote",
13
+ "faiss",
14
+ "pytorch_lightning.utilities.migration.utils",
15
+ "pytorch_lightning.utilities.migration",
16
+ "pytorch_lightning",
17
+ "lightning",
18
+ "lightning.pytorch.utilities.migration.utils",
19
+ ]
20
+ try:
21
+ for module in modules:
22
+ logging.getLogger(module).setLevel(logging.WARNING)
23
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3" if not debug else "1"
24
+
25
+ # fix verbose pyannote audio
26
+ def fix_verbose_pyannote(*args, what=""):
27
+ pass
28
+ import pyannote.audio.core.model # noqa
29
+ pyannote.audio.core.model.check_version = fix_verbose_pyannote
30
+ except Exception as error:
31
+ logger.error(str(error))
32
+
33
+
34
+ def setup_logger(name_log):
35
+ logger = logging.getLogger(name_log)
36
+ logger.setLevel(logging.INFO)
37
+
38
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
39
+ _default_handler.flush = sys.stderr.flush
40
+ logger.addHandler(_default_handler)
41
+
42
+ logger.propagate = False
43
+
44
+ handlers = logger.handlers
45
+
46
+ for handler in handlers:
47
+ formatter = logging.Formatter("[%(levelname)s] >> %(message)s")
48
+ handler.setFormatter(formatter)
49
+
50
+ # logger.handlers
51
+
52
+ return logger
53
+
54
+
55
+ logger = setup_logger("sonitranslate")
56
+ logger.setLevel(logging.INFO)
57
+
58
+
59
+ def set_logging_level(verbosity_level):
60
+ logging_level_mapping = {
61
+ "debug": logging.DEBUG,
62
+ "info": logging.INFO,
63
+ "warning": logging.WARNING,
64
+ "error": logging.ERROR,
65
+ "critical": logging.CRITICAL,
66
+ }
67
+
68
+ logger.setLevel(logging_level_mapping.get(verbosity_level, logging.INFO))
soni_translate/mdx_net.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import hashlib
3
+ import os
4
+ import queue
5
+ import threading
6
+ import json
7
+ import shlex
8
+ import sys
9
+ import subprocess
10
+ import librosa
11
+ import numpy as np
12
+ import soundfile as sf
13
+ import torch
14
+ from tqdm import tqdm
15
+
16
+ try:
17
+ from .utils import (
18
+ remove_directory_contents,
19
+ create_directories,
20
+ )
21
+ except: # noqa
22
+ from utils import (
23
+ remove_directory_contents,
24
+ create_directories,
25
+ )
26
+ from .logging_setup import logger
27
+
28
+ try:
29
+ import onnxruntime as ort
30
+ except Exception as error:
31
+ logger.error(str(error))
32
+ # import warnings
33
+ # warnings.filterwarnings("ignore")
34
+
35
+ stem_naming = {
36
+ "Vocals": "Instrumental",
37
+ "Other": "Instruments",
38
+ "Instrumental": "Vocals",
39
+ "Drums": "Drumless",
40
+ "Bass": "Bassless",
41
+ }
42
+
43
+
44
+ class MDXModel:
45
+ def __init__(
46
+ self,
47
+ device,
48
+ dim_f,
49
+ dim_t,
50
+ n_fft,
51
+ hop=1024,
52
+ stem_name=None,
53
+ compensation=1.000,
54
+ ):
55
+ self.dim_f = dim_f
56
+ self.dim_t = dim_t
57
+ self.dim_c = 4
58
+ self.n_fft = n_fft
59
+ self.hop = hop
60
+ self.stem_name = stem_name
61
+ self.compensation = compensation
62
+
63
+ self.n_bins = self.n_fft // 2 + 1
64
+ self.chunk_size = hop * (self.dim_t - 1)
65
+ self.window = torch.hann_window(
66
+ window_length=self.n_fft, periodic=True
67
+ ).to(device)
68
+
69
+ out_c = self.dim_c
70
+
71
+ self.freq_pad = torch.zeros(
72
+ [1, out_c, self.n_bins - self.dim_f, self.dim_t]
73
+ ).to(device)
74
+
75
+ def stft(self, x):
76
+ x = x.reshape([-1, self.chunk_size])
77
+ x = torch.stft(
78
+ x,
79
+ n_fft=self.n_fft,
80
+ hop_length=self.hop,
81
+ window=self.window,
82
+ center=True,
83
+ return_complex=True,
84
+ )
85
+ x = torch.view_as_real(x)
86
+ x = x.permute([0, 3, 1, 2])
87
+ x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
88
+ [-1, 4, self.n_bins, self.dim_t]
89
+ )
90
+ return x[:, :, : self.dim_f]
91
+
92
+ def istft(self, x, freq_pad=None):
93
+ freq_pad = (
94
+ self.freq_pad.repeat([x.shape[0], 1, 1, 1])
95
+ if freq_pad is None
96
+ else freq_pad
97
+ )
98
+ x = torch.cat([x, freq_pad], -2)
99
+ # c = 4*2 if self.target_name=='*' else 2
100
+ x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
101
+ [-1, 2, self.n_bins, self.dim_t]
102
+ )
103
+ x = x.permute([0, 2, 3, 1])
104
+ x = x.contiguous()
105
+ x = torch.view_as_complex(x)
106
+ x = torch.istft(
107
+ x,
108
+ n_fft=self.n_fft,
109
+ hop_length=self.hop,
110
+ window=self.window,
111
+ center=True,
112
+ )
113
+ return x.reshape([-1, 2, self.chunk_size])
114
+
115
+
116
+ class MDX:
117
+ DEFAULT_SR = 44100
118
+ # Unit: seconds
119
+ DEFAULT_CHUNK_SIZE = 0 * DEFAULT_SR
120
+ DEFAULT_MARGIN_SIZE = 1 * DEFAULT_SR
121
+
122
+ def __init__(
123
+ self, model_path: str, params: MDXModel, processor=0
124
+ ):
125
+ # Set the device and the provider (CPU or CUDA)
126
+ self.device = (
127
+ torch.device(f"cuda:{processor}")
128
+ if processor >= 0
129
+ else torch.device("cpu")
130
+ )
131
+ self.provider = (
132
+ ["CUDAExecutionProvider"]
133
+ if processor >= 0
134
+ else ["CPUExecutionProvider"]
135
+ )
136
+
137
+ self.model = params
138
+
139
+ # Load the ONNX model using ONNX Runtime
140
+ self.ort = ort.InferenceSession(model_path, providers=self.provider)
141
+ # Preload the model for faster performance
142
+ self.ort.run(
143
+ None,
144
+ {"input": torch.rand(1, 4, params.dim_f, params.dim_t).numpy()},
145
+ )
146
+ self.process = lambda spec: self.ort.run(
147
+ None, {"input": spec.cpu().numpy()}
148
+ )[0]
149
+
150
+ self.prog = None
151
+
152
+ @staticmethod
153
+ def get_hash(model_path):
154
+ try:
155
+ with open(model_path, "rb") as f:
156
+ f.seek(-10000 * 1024, 2)
157
+ model_hash = hashlib.md5(f.read()).hexdigest()
158
+ except: # noqa
159
+ model_hash = hashlib.md5(open(model_path, "rb").read()).hexdigest()
160
+
161
+ return model_hash
162
+
163
+ @staticmethod
164
+ def segment(
165
+ wave,
166
+ combine=True,
167
+ chunk_size=DEFAULT_CHUNK_SIZE,
168
+ margin_size=DEFAULT_MARGIN_SIZE,
169
+ ):
170
+ """
171
+ Segment or join segmented wave array
172
+
173
+ Args:
174
+ wave: (np.array) Wave array to be segmented or joined
175
+ combine: (bool) If True, combines segmented wave array.
176
+ If False, segments wave array.
177
+ chunk_size: (int) Size of each segment (in samples)
178
+ margin_size: (int) Size of margin between segments (in samples)
179
+
180
+ Returns:
181
+ numpy array: Segmented or joined wave array
182
+ """
183
+
184
+ if combine:
185
+ # Initializing as None instead of [] for later numpy array concatenation
186
+ processed_wave = None
187
+ for segment_count, segment in enumerate(wave):
188
+ start = 0 if segment_count == 0 else margin_size
189
+ end = None if segment_count == len(wave) - 1 else -margin_size
190
+ if margin_size == 0:
191
+ end = None
192
+ if processed_wave is None: # Create array for first segment
193
+ processed_wave = segment[:, start:end]
194
+ else: # Concatenate to existing array for subsequent segments
195
+ processed_wave = np.concatenate(
196
+ (processed_wave, segment[:, start:end]), axis=-1
197
+ )
198
+
199
+ else:
200
+ processed_wave = []
201
+ sample_count = wave.shape[-1]
202
+
203
+ if chunk_size <= 0 or chunk_size > sample_count:
204
+ chunk_size = sample_count
205
+
206
+ if margin_size > chunk_size:
207
+ margin_size = chunk_size
208
+
209
+ for segment_count, skip in enumerate(
210
+ range(0, sample_count, chunk_size)
211
+ ):
212
+ margin = 0 if segment_count == 0 else margin_size
213
+ end = min(skip + chunk_size + margin_size, sample_count)
214
+ start = skip - margin
215
+
216
+ cut = wave[:, start:end].copy()
217
+ processed_wave.append(cut)
218
+
219
+ if end == sample_count:
220
+ break
221
+
222
+ return processed_wave
223
+
224
+ def pad_wave(self, wave):
225
+ """
226
+ Pad the wave array to match the required chunk size
227
+
228
+ Args:
229
+ wave: (np.array) Wave array to be padded
230
+
231
+ Returns:
232
+ tuple: (padded_wave, pad, trim)
233
+ - padded_wave: Padded wave array
234
+ - pad: Number of samples that were padded
235
+ - trim: Number of samples that were trimmed
236
+ """
237
+ n_sample = wave.shape[1]
238
+ trim = self.model.n_fft // 2
239
+ gen_size = self.model.chunk_size - 2 * trim
240
+ pad = gen_size - n_sample % gen_size
241
+
242
+ # Padded wave
243
+ wave_p = np.concatenate(
244
+ (
245
+ np.zeros((2, trim)),
246
+ wave,
247
+ np.zeros((2, pad)),
248
+ np.zeros((2, trim)),
249
+ ),
250
+ 1,
251
+ )
252
+
253
+ mix_waves = []
254
+ for i in range(0, n_sample + pad, gen_size):
255
+ waves = np.array(wave_p[:, i:i + self.model.chunk_size])
256
+ mix_waves.append(waves)
257
+
258
+ mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(
259
+ self.device
260
+ )
261
+
262
+ return mix_waves, pad, trim
263
+
264
+ def _process_wave(self, mix_waves, trim, pad, q: queue.Queue, _id: int):
265
+ """
266
+ Process each wave segment in a multi-threaded environment
267
+
268
+ Args:
269
+ mix_waves: (torch.Tensor) Wave segments to be processed
270
+ trim: (int) Number of samples trimmed during padding
271
+ pad: (int) Number of samples padded during padding
272
+ q: (queue.Queue) Queue to hold the processed wave segments
273
+ _id: (int) Identifier of the processed wave segment
274
+
275
+ Returns:
276
+ numpy array: Processed wave segment
277
+ """
278
+ mix_waves = mix_waves.split(1)
279
+ with torch.no_grad():
280
+ pw = []
281
+ for mix_wave in mix_waves:
282
+ self.prog.update()
283
+ spec = self.model.stft(mix_wave)
284
+ processed_spec = torch.tensor(self.process(spec))
285
+ processed_wav = self.model.istft(
286
+ processed_spec.to(self.device)
287
+ )
288
+ processed_wav = (
289
+ processed_wav[:, :, trim:-trim]
290
+ .transpose(0, 1)
291
+ .reshape(2, -1)
292
+ .cpu()
293
+ .numpy()
294
+ )
295
+ pw.append(processed_wav)
296
+ processed_signal = np.concatenate(pw, axis=-1)[:, :-pad]
297
+ q.put({_id: processed_signal})
298
+ return processed_signal
299
+
300
+ def process_wave(self, wave: np.array, mt_threads=1):
301
+ """
302
+ Process the wave array in a multi-threaded environment
303
+
304
+ Args:
305
+ wave: (np.array) Wave array to be processed
306
+ mt_threads: (int) Number of threads to be used for processing
307
+
308
+ Returns:
309
+ numpy array: Processed wave array
310
+ """
311
+ self.prog = tqdm(total=0)
312
+ chunk = wave.shape[-1] // mt_threads
313
+ waves = self.segment(wave, False, chunk)
314
+
315
+ # Create a queue to hold the processed wave segments
316
+ q = queue.Queue()
317
+ threads = []
318
+ for c, batch in enumerate(waves):
319
+ mix_waves, pad, trim = self.pad_wave(batch)
320
+ self.prog.total = len(mix_waves) * mt_threads
321
+ thread = threading.Thread(
322
+ target=self._process_wave, args=(mix_waves, trim, pad, q, c)
323
+ )
324
+ thread.start()
325
+ threads.append(thread)
326
+ for thread in threads:
327
+ thread.join()
328
+ self.prog.close()
329
+
330
+ processed_batches = []
331
+ while not q.empty():
332
+ processed_batches.append(q.get())
333
+ processed_batches = [
334
+ list(wave.values())[0]
335
+ for wave in sorted(
336
+ processed_batches, key=lambda d: list(d.keys())[0]
337
+ )
338
+ ]
339
+ assert len(processed_batches) == len(
340
+ waves
341
+ ), "Incomplete processed batches, please reduce batch size!"
342
+ return self.segment(processed_batches, True, chunk)
343
+
344
+
345
+ def run_mdx(
346
+ model_params,
347
+ output_dir,
348
+ model_path,
349
+ filename,
350
+ exclude_main=False,
351
+ exclude_inversion=False,
352
+ suffix=None,
353
+ invert_suffix=None,
354
+ denoise=False,
355
+ keep_orig=True,
356
+ m_threads=2,
357
+ device_base="cuda",
358
+ ):
359
+ if device_base == "cuda":
360
+ device = torch.device("cuda:0")
361
+ processor_num = 0
362
+ device_properties = torch.cuda.get_device_properties(device)
363
+ vram_gb = device_properties.total_memory / 1024**3
364
+ m_threads = 1 if vram_gb < 8 else 2
365
+ else:
366
+ device = torch.device("cpu")
367
+ processor_num = -1
368
+ m_threads = 1
369
+
370
+ model_hash = MDX.get_hash(model_path)
371
+ mp = model_params.get(model_hash)
372
+ model = MDXModel(
373
+ device,
374
+ dim_f=mp["mdx_dim_f_set"],
375
+ dim_t=2 ** mp["mdx_dim_t_set"],
376
+ n_fft=mp["mdx_n_fft_scale_set"],
377
+ stem_name=mp["primary_stem"],
378
+ compensation=mp["compensate"],
379
+ )
380
+
381
+ mdx_sess = MDX(model_path, model, processor=processor_num)
382
+ wave, sr = librosa.load(filename, mono=False, sr=44100)
383
+ # normalizing input wave gives better output
384
+ peak = max(np.max(wave), abs(np.min(wave)))
385
+ wave /= peak
386
+ if denoise:
387
+ wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (
388
+ mdx_sess.process_wave(wave, m_threads)
389
+ )
390
+ wave_processed *= 0.5
391
+ else:
392
+ wave_processed = mdx_sess.process_wave(wave, m_threads)
393
+ # return to previous peak
394
+ wave_processed *= peak
395
+ stem_name = model.stem_name if suffix is None else suffix
396
+
397
+ main_filepath = None
398
+ if not exclude_main:
399
+ main_filepath = os.path.join(
400
+ output_dir,
401
+ f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
402
+ )
403
+ sf.write(main_filepath, wave_processed.T, sr)
404
+
405
+ invert_filepath = None
406
+ if not exclude_inversion:
407
+ diff_stem_name = (
408
+ stem_naming.get(stem_name)
409
+ if invert_suffix is None
410
+ else invert_suffix
411
+ )
412
+ stem_name = (
413
+ f"{stem_name}_diff" if diff_stem_name is None else diff_stem_name
414
+ )
415
+ invert_filepath = os.path.join(
416
+ output_dir,
417
+ f"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav",
418
+ )
419
+ sf.write(
420
+ invert_filepath,
421
+ (-wave_processed.T * model.compensation) + wave.T,
422
+ sr,
423
+ )
424
+
425
+ if not keep_orig:
426
+ os.remove(filename)
427
+
428
+ del mdx_sess, wave_processed, wave
429
+ gc.collect()
430
+ torch.cuda.empty_cache()
431
+ return main_filepath, invert_filepath
432
+
433
+
434
+ MDX_DOWNLOAD_LINK = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/"
435
+ UVR_MODELS = [
436
+ "UVR-MDX-NET-Voc_FT.onnx",
437
+ "UVR_MDXNET_KARA_2.onnx",
438
+ "Reverb_HQ_By_FoxJoy.onnx",
439
+ "UVR-MDX-NET-Inst_HQ_4.onnx",
440
+ ]
441
+ BASE_DIR = "." # os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
442
+ mdxnet_models_dir = os.path.join(BASE_DIR, "mdx_models")
443
+ output_dir = os.path.join(BASE_DIR, "clean_song_output")
444
+
445
+
446
+ def convert_to_stereo_and_wav(audio_path):
447
+ wave, sr = librosa.load(audio_path, mono=False, sr=44100)
448
+
449
+ # check if mono
450
+ if type(wave[0]) != np.ndarray or audio_path[-4:].lower() != ".wav": # noqa
451
+ stereo_path = f"{os.path.splitext(audio_path)[0]}_stereo.wav"
452
+ stereo_path = os.path.join(output_dir, stereo_path)
453
+
454
+ command = shlex.split(
455
+ f'ffmpeg -y -loglevel error -i "{audio_path}" -ac 2 -f wav "{stereo_path}"'
456
+ )
457
+ sub_params = {
458
+ "stdout": subprocess.PIPE,
459
+ "stderr": subprocess.PIPE,
460
+ "creationflags": subprocess.CREATE_NO_WINDOW
461
+ if sys.platform == "win32"
462
+ else 0,
463
+ }
464
+ process_wav = subprocess.Popen(command, **sub_params)
465
+ output, errors = process_wav.communicate()
466
+ if process_wav.returncode != 0 or not os.path.exists(stereo_path):
467
+ raise Exception("Error processing audio to stereo wav")
468
+
469
+ return stereo_path
470
+ else:
471
+ return audio_path
472
+
473
+
474
+ def process_uvr_task(
475
+ orig_song_path: str = "aud_test.mp3",
476
+ main_vocals: bool = False,
477
+ dereverb: bool = True,
478
+ song_id: str = "mdx", # folder output name
479
+ only_voiceless: bool = False,
480
+ remove_files_output_dir: bool = False,
481
+ ):
482
+ if os.environ.get("SONITR_DEVICE") == "cpu":
483
+ device_base = "cpu"
484
+ else:
485
+ device_base = "cuda" if torch.cuda.is_available() else "cpu"
486
+
487
+ if remove_files_output_dir:
488
+ remove_directory_contents(output_dir)
489
+
490
+ with open(os.path.join(mdxnet_models_dir, "data.json")) as infile:
491
+ mdx_model_params = json.load(infile)
492
+
493
+ song_output_dir = os.path.join(output_dir, song_id)
494
+ create_directories(song_output_dir)
495
+ orig_song_path = convert_to_stereo_and_wav(orig_song_path)
496
+
497
+ logger.debug(f"onnxruntime device >> {ort.get_device()}")
498
+
499
+ if only_voiceless:
500
+ logger.info("Voiceless Track Separation...")
501
+ return run_mdx(
502
+ mdx_model_params,
503
+ song_output_dir,
504
+ os.path.join(mdxnet_models_dir, "UVR-MDX-NET-Inst_HQ_4.onnx"),
505
+ orig_song_path,
506
+ suffix="Voiceless",
507
+ denoise=False,
508
+ keep_orig=True,
509
+ exclude_inversion=True,
510
+ device_base=device_base,
511
+ )
512
+
513
+ logger.info("Vocal Track Isolation and Voiceless Track Separation...")
514
+ vocals_path, instrumentals_path = run_mdx(
515
+ mdx_model_params,
516
+ song_output_dir,
517
+ os.path.join(mdxnet_models_dir, "UVR-MDX-NET-Voc_FT.onnx"),
518
+ orig_song_path,
519
+ denoise=True,
520
+ keep_orig=True,
521
+ device_base=device_base,
522
+ )
523
+
524
+ if main_vocals:
525
+ logger.info("Main Voice Separation from Supporting Vocals...")
526
+ backup_vocals_path, main_vocals_path = run_mdx(
527
+ mdx_model_params,
528
+ song_output_dir,
529
+ os.path.join(mdxnet_models_dir, "UVR_MDXNET_KARA_2.onnx"),
530
+ vocals_path,
531
+ suffix="Backup",
532
+ invert_suffix="Main",
533
+ denoise=True,
534
+ device_base=device_base,
535
+ )
536
+ else:
537
+ backup_vocals_path, main_vocals_path = None, vocals_path
538
+
539
+ if dereverb:
540
+ logger.info("Vocal Clarity Enhancement through De-Reverberation...")
541
+ _, vocals_dereverb_path = run_mdx(
542
+ mdx_model_params,
543
+ song_output_dir,
544
+ os.path.join(mdxnet_models_dir, "Reverb_HQ_By_FoxJoy.onnx"),
545
+ main_vocals_path,
546
+ invert_suffix="DeReverb",
547
+ exclude_main=True,
548
+ denoise=True,
549
+ device_base=device_base,
550
+ )
551
+ else:
552
+ vocals_dereverb_path = main_vocals_path
553
+
554
+ return (
555
+ vocals_path,
556
+ instrumentals_path,
557
+ backup_vocals_path,
558
+ main_vocals_path,
559
+ vocals_dereverb_path,
560
+ )
561
+
562
+
563
+ if __name__ == "__main__":
564
+ from utils import download_manager
565
+
566
+ for id_model in UVR_MODELS:
567
+ download_manager(
568
+ os.path.join(MDX_DOWNLOAD_LINK, id_model), mdxnet_models_dir
569
+ )
570
+ (
571
+ vocals_path_,
572
+ instrumentals_path_,
573
+ backup_vocals_path_,
574
+ main_vocals_path_,
575
+ vocals_dereverb_path_,
576
+ ) = process_uvr_task(
577
+ orig_song_path="aud.mp3",
578
+ main_vocals=True,
579
+ dereverb=True,
580
+ song_id="mdx",
581
+ remove_files_output_dir=True,
582
+ )
soni_translate/postprocessor.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import remove_files, run_command
2
+ from .text_multiformat_processor import get_subtitle
3
+ from .logging_setup import logger
4
+ import unicodedata
5
+ import shutil
6
+ import copy
7
+ import os
8
+ import re
9
+
10
+ OUTPUT_TYPE_OPTIONS = [
11
+ "video (mp4)",
12
+ "video (mkv)",
13
+ "audio (mp3)",
14
+ "audio (ogg)",
15
+ "audio (wav)",
16
+ "subtitle",
17
+ "subtitle [by speaker]",
18
+ "video [subtitled] (mp4)",
19
+ "video [subtitled] (mkv)",
20
+ "audio [original vocal sound]",
21
+ "audio [original background sound]",
22
+ "audio [original vocal and background sound]",
23
+ "audio [original vocal-dereverb sound]",
24
+ "audio [original vocal-dereverb and background sound]",
25
+ "raw media",
26
+ ]
27
+
28
+ DOCS_OUTPUT_TYPE_OPTIONS = [
29
+ "videobook (mp4)",
30
+ "videobook (mkv)",
31
+ "audiobook (wav)",
32
+ "audiobook (mp3)",
33
+ "audiobook (ogg)",
34
+ "book (txt)",
35
+ ] # Add DOCX and etc.
36
+
37
+
38
+ def get_no_ext_filename(file_path):
39
+ file_name_with_extension = os.path.basename(rf"{file_path}")
40
+ filename_without_extension, _ = os.path.splitext(file_name_with_extension)
41
+ return filename_without_extension
42
+
43
+
44
+ def get_video_info(link):
45
+ aux_name = f"video_url_{link}"
46
+ params_dlp = {"quiet": True, "no_warnings": True, "noplaylist": True}
47
+ try:
48
+ from yt_dlp import YoutubeDL
49
+
50
+ with YoutubeDL(params_dlp) as ydl:
51
+ if link.startswith(("www.youtube.com/", "m.youtube.com/")):
52
+ link = "https://" + link
53
+ info_dict = ydl.extract_info(link, download=False, process=False)
54
+ video_id = info_dict.get("id", aux_name)
55
+ video_title = info_dict.get("title", video_id)
56
+ if "youtube.com" in link and "&list=" in link:
57
+ video_title = ydl.extract_info(
58
+ "https://m.youtube.com/watch?v="+video_id,
59
+ download=False,
60
+ process=False
61
+ ).get("title", video_title)
62
+ except Exception as error:
63
+ logger.error(str(error))
64
+ video_title, video_id = aux_name, "NO_ID"
65
+ return video_title, video_id
66
+
67
+
68
+ def sanitize_file_name(file_name):
69
+ # Normalize the string to NFKD form to separate combined
70
+ # characters into base characters and diacritics
71
+ normalized_name = unicodedata.normalize("NFKD", file_name)
72
+ # Replace any non-ASCII characters or special symbols with an underscore
73
+ sanitized_name = re.sub(r"[^\w\s.-]", "_", normalized_name)
74
+ return sanitized_name
75
+
76
+
77
+ def get_output_file(
78
+ original_file,
79
+ new_file_name,
80
+ soft_subtitles,
81
+ output_directory="",
82
+ ):
83
+ directory_base = "." # default directory
84
+
85
+ if output_directory and os.path.isdir(output_directory):
86
+ new_file_path = os.path.join(output_directory, new_file_name)
87
+ else:
88
+ new_file_path = os.path.join(directory_base, "outputs", new_file_name)
89
+ remove_files(new_file_path)
90
+
91
+ cm = None
92
+ if soft_subtitles and original_file.endswith(".mp4"):
93
+ if new_file_path.endswith(".mp4"):
94
+ cm = f'ffmpeg -y -i "{original_file}" -i sub_tra.srt -i sub_ori.srt -map 0:v -map 0:a -map 1 -map 2 -c:v copy -c:a copy -c:s mov_text "{new_file_path}"'
95
+ else:
96
+ cm = f'ffmpeg -y -i "{original_file}" -i sub_tra.srt -i sub_ori.srt -map 0:v -map 0:a -map 1 -map 2 -c:v copy -c:a copy -c:s srt -movflags use_metadata_tags -map_metadata 0 "{new_file_path}"'
97
+ elif new_file_path.endswith(".mkv"):
98
+ cm = f'ffmpeg -i "{original_file}" -c:v copy -c:a copy "{new_file_path}"'
99
+ elif new_file_path.endswith(".wav") and not original_file.endswith(".wav"):
100
+ cm = f'ffmpeg -y -i "{original_file}" -acodec pcm_s16le -ar 44100 -ac 2 "{new_file_path}"'
101
+ elif new_file_path.endswith(".ogg"):
102
+ cm = f'ffmpeg -i "{original_file}" -c:a libvorbis "{new_file_path}"'
103
+ elif new_file_path.endswith(".mp3") and not original_file.endswith(".mp3"):
104
+ cm = f'ffmpeg -y -i "{original_file}" -codec:a libmp3lame -qscale:a 2 "{new_file_path}"'
105
+
106
+ if cm:
107
+ try:
108
+ run_command(cm)
109
+ except Exception as error:
110
+ logger.error(str(error))
111
+ remove_files(new_file_path)
112
+ shutil.copy2(original_file, new_file_path)
113
+ else:
114
+ shutil.copy2(original_file, new_file_path)
115
+
116
+ return os.path.abspath(new_file_path)
117
+
118
+
119
+ def media_out(
120
+ media_file,
121
+ lang_code,
122
+ media_out_name="",
123
+ extension="mp4",
124
+ file_obj="video_dub.mp4",
125
+ soft_subtitles=False,
126
+ subtitle_files="disable",
127
+ ):
128
+ if not media_out_name:
129
+ if os.path.exists(media_file):
130
+ base_name = get_no_ext_filename(media_file)
131
+ else:
132
+ base_name, _ = get_video_info(media_file)
133
+
134
+ media_out_name = f"{base_name}__{lang_code}"
135
+
136
+ f_name = f"{sanitize_file_name(media_out_name)}.{extension}"
137
+
138
+ if subtitle_files != "disable":
139
+ final_media = [get_output_file(file_obj, f_name, soft_subtitles)]
140
+ name_tra = f"{sanitize_file_name(media_out_name)}.{subtitle_files}"
141
+ name_ori = f"{sanitize_file_name(base_name)}.{subtitle_files}"
142
+ tgt_subs = f"sub_tra.{subtitle_files}"
143
+ ori_subs = f"sub_ori.{subtitle_files}"
144
+ final_subtitles = [
145
+ get_output_file(tgt_subs, name_tra, False),
146
+ get_output_file(ori_subs, name_ori, False)
147
+ ]
148
+ return final_media + final_subtitles
149
+ else:
150
+ return get_output_file(file_obj, f_name, soft_subtitles)
151
+
152
+
153
+ def get_subtitle_speaker(media_file, result, language, extension, base_name):
154
+
155
+ segments_base = copy.deepcopy(result)
156
+
157
+ # Sub segments by speaker
158
+ segments_by_speaker = {}
159
+ for segment in segments_base["segments"]:
160
+ if segment["speaker"] not in segments_by_speaker.keys():
161
+ segments_by_speaker[segment["speaker"]] = [segment]
162
+ else:
163
+ segments_by_speaker[segment["speaker"]].append(segment)
164
+
165
+ if not base_name:
166
+ if os.path.exists(media_file):
167
+ base_name = get_no_ext_filename(media_file)
168
+ else:
169
+ base_name, _ = get_video_info(media_file)
170
+
171
+ files_subs = []
172
+ for name_sk, segments in segments_by_speaker.items():
173
+
174
+ subtitle_speaker = get_subtitle(
175
+ language,
176
+ {"segments": segments},
177
+ extension,
178
+ filename=name_sk,
179
+ )
180
+
181
+ media_out_name = f"{base_name}_{language}_{name_sk}"
182
+
183
+ output = media_out(
184
+ media_file, # no need
185
+ language,
186
+ media_out_name,
187
+ extension,
188
+ file_obj=subtitle_speaker,
189
+ )
190
+
191
+ files_subs.append(output)
192
+
193
+ return files_subs
194
+
195
+
196
+ def sound_separate(media_file, task_uvr):
197
+ from .mdx_net import process_uvr_task
198
+
199
+ outputs = []
200
+
201
+ if "vocal" in task_uvr:
202
+ try:
203
+ _, _, _, _, vocal_audio = process_uvr_task(
204
+ orig_song_path=media_file,
205
+ main_vocals=False,
206
+ dereverb=True if "dereverb" in task_uvr else False,
207
+ remove_files_output_dir=True,
208
+ )
209
+ outputs.append(vocal_audio)
210
+ except Exception as error:
211
+ logger.error(str(error))
212
+
213
+ if "background" in task_uvr:
214
+ try:
215
+ background_audio, _ = process_uvr_task(
216
+ orig_song_path=media_file,
217
+ song_id="voiceless",
218
+ only_voiceless=True,
219
+ remove_files_output_dir=False if "vocal" in task_uvr else True,
220
+ )
221
+ # copy_files(background_audio, ".")
222
+ outputs.append(background_audio)
223
+ except Exception as error:
224
+ logger.error(str(error))
225
+
226
+ if not outputs:
227
+ raise Exception("Error in uvr process")
228
+
229
+ return outputs
soni_translate/preprocessor.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import remove_files
2
+ import os, shutil, subprocess, time, shlex, sys # noqa
3
+ from .logging_setup import logger
4
+ import json
5
+
6
+ ERROR_INCORRECT_CODEC_PARAMETERS = [
7
+ "prores", # mov
8
+ "ffv1", # mkv
9
+ "msmpeg4v3", # avi
10
+ "wmv2", # wmv
11
+ "theora", # ogv
12
+ ] # fix final merge
13
+
14
+ TESTED_CODECS = [
15
+ "h264", # mp4
16
+ "h265", # mp4
17
+ "vp9", # webm
18
+ "mpeg4", # mp4
19
+ "mpeg2video", # mpg
20
+ "mjpeg", # avi
21
+ ]
22
+
23
+
24
+ class OperationFailedError(Exception):
25
+ def __init__(self, message="The operation did not complete successfully."):
26
+ self.message = message
27
+ super().__init__(self.message)
28
+
29
+
30
+ def get_video_codec(video_file):
31
+ command_base = rf'ffprobe -v error -select_streams v:0 -show_entries stream=codec_name -of json "{video_file}"'
32
+ command = shlex.split(command_base)
33
+ try:
34
+ process = subprocess.Popen(
35
+ command,
36
+ stdout=subprocess.PIPE,
37
+ creationflags=subprocess.CREATE_NO_WINDOW if sys.platform == "win32" else 0,
38
+ )
39
+ output, _ = process.communicate()
40
+ codec_info = json.loads(output.decode('utf-8'))
41
+ codec_name = codec_info['streams'][0]['codec_name']
42
+ return codec_name
43
+ except Exception as error:
44
+ logger.debug(str(error))
45
+ return None
46
+
47
+
48
+ def audio_preprocessor(preview, base_audio, audio_wav, use_cuda=False):
49
+ base_audio = base_audio.strip()
50
+ previous_files_to_remove = [audio_wav]
51
+ remove_files(previous_files_to_remove)
52
+
53
+ if preview:
54
+ logger.warning(
55
+ "Creating a preview video of 10 seconds, to disable "
56
+ "this option, go to advanced settings and turn off preview."
57
+ )
58
+ wav_ = f'ffmpeg -y -i "{base_audio}" -ss 00:00:20 -t 00:00:10 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav'
59
+ else:
60
+ wav_ = f'ffmpeg -y -i "{base_audio}" -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav'
61
+
62
+ # Run cmd process
63
+ sub_params = {
64
+ "stdout": subprocess.PIPE,
65
+ "stderr": subprocess.PIPE,
66
+ "creationflags": subprocess.CREATE_NO_WINDOW
67
+ if sys.platform == "win32"
68
+ else 0,
69
+ }
70
+ wav_ = shlex.split(wav_)
71
+ result_convert_audio = subprocess.Popen(wav_, **sub_params)
72
+ output, errors = result_convert_audio.communicate()
73
+ time.sleep(1)
74
+ if result_convert_audio.returncode in [1, 2] or not os.path.exists(
75
+ audio_wav
76
+ ):
77
+ raise OperationFailedError(f"Error can't create the audio file:\n{errors.decode('utf-8')}")
78
+
79
+
80
+ def audio_video_preprocessor(
81
+ preview, video, OutputFile, audio_wav, use_cuda=False
82
+ ):
83
+ video = video.strip()
84
+ previous_files_to_remove = [OutputFile, "audio.webm", audio_wav]
85
+ remove_files(previous_files_to_remove)
86
+
87
+ if os.path.exists(video):
88
+ if preview:
89
+ logger.warning(
90
+ "Creating a preview video of 10 seconds, "
91
+ "to disable this option, go to advanced "
92
+ "settings and turn off preview."
93
+ )
94
+ mp4_ = f'ffmpeg -y -i "{video}" -ss 00:00:20 -t 00:00:10 -c:v libx264 -c:a aac -strict experimental Video.mp4'
95
+ else:
96
+ video_codec = get_video_codec(video)
97
+ if not video_codec:
98
+ logger.debug("No video codec found in video")
99
+ else:
100
+ logger.info(f"Video codec: {video_codec}")
101
+
102
+ # Check if the file ends with ".mp4" extension or is valid codec
103
+ if video.endswith(".mp4") or video_codec in TESTED_CODECS:
104
+ destination_path = os.path.join(os.getcwd(), "Video.mp4")
105
+ shutil.copy(video, destination_path)
106
+ time.sleep(0.5)
107
+ if os.path.exists(OutputFile):
108
+ mp4_ = "ffmpeg -h"
109
+ else:
110
+ mp4_ = f'ffmpeg -y -i "{video}" -c copy Video.mp4'
111
+ else:
112
+ logger.warning(
113
+ "File does not have the '.mp4' extension or a "
114
+ "supported codec. Converting video to mp4 (codec: h264)."
115
+ )
116
+ mp4_ = f'ffmpeg -y -i "{video}" -c:v libx264 -c:a aac -strict experimental Video.mp4'
117
+ else:
118
+ if preview:
119
+ logger.warning(
120
+ "Creating a preview from the link, 10 seconds "
121
+ "to disable this option, go to advanced "
122
+ "settings and turn off preview."
123
+ )
124
+ # https://github.com/yt-dlp/yt-dlp/issues/2220
125
+ mp4_ = f'yt-dlp -f "mp4" --downloader ffmpeg --downloader-args "ffmpeg_i: -ss 00:00:20 -t 00:00:10" --force-overwrites --max-downloads 1 --no-warnings --no-playlist --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
126
+ wav_ = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
127
+ else:
128
+ mp4_ = f'yt-dlp -f "mp4" --force-overwrites --max-downloads 1 --no-warnings --no-playlist --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
129
+ wav_ = f"python -m yt_dlp --output {audio_wav} --force-overwrites --max-downloads 1 --no-warnings --no-playlist --no-abort-on-error --ignore-no-formats-error --extract-audio --audio-format wav {video}"
130
+
131
+ # Run cmd process
132
+ mp4_ = shlex.split(mp4_)
133
+ sub_params = {
134
+ "stdout": subprocess.PIPE,
135
+ "stderr": subprocess.PIPE,
136
+ "creationflags": subprocess.CREATE_NO_WINDOW
137
+ if sys.platform == "win32"
138
+ else 0,
139
+ }
140
+
141
+ if os.path.exists(video):
142
+ logger.info("Process video...")
143
+ result_convert_video = subprocess.Popen(mp4_, **sub_params)
144
+ # result_convert_video.wait()
145
+ output, errors = result_convert_video.communicate()
146
+ time.sleep(1)
147
+ if result_convert_video.returncode in [1, 2] or not os.path.exists(
148
+ OutputFile
149
+ ):
150
+ raise OperationFailedError(f"Error processing video:\n{errors.decode('utf-8')}")
151
+ logger.info("Process audio...")
152
+ wav_ = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
153
+ wav_ = shlex.split(wav_)
154
+ result_convert_audio = subprocess.Popen(wav_, **sub_params)
155
+ output, errors = result_convert_audio.communicate()
156
+ time.sleep(1)
157
+ if result_convert_audio.returncode in [1, 2] or not os.path.exists(
158
+ audio_wav
159
+ ):
160
+ raise OperationFailedError(f"Error can't create the audio file:\n{errors.decode('utf-8')}")
161
+
162
+ else:
163
+ wav_ = shlex.split(wav_)
164
+ if preview:
165
+ result_convert_video = subprocess.Popen(mp4_, **sub_params)
166
+ output, errors = result_convert_video.communicate()
167
+ time.sleep(0.5)
168
+ result_convert_audio = subprocess.Popen(wav_, **sub_params)
169
+ output, errors = result_convert_audio.communicate()
170
+ time.sleep(0.5)
171
+ if result_convert_audio.returncode in [1, 2] or not os.path.exists(
172
+ audio_wav
173
+ ):
174
+ raise OperationFailedError(
175
+ f"Error can't create the preview file:\n{errors.decode('utf-8')}"
176
+ )
177
+ else:
178
+ logger.info("Process audio...")
179
+ result_convert_audio = subprocess.Popen(wav_, **sub_params)
180
+ output, errors = result_convert_audio.communicate()
181
+ time.sleep(1)
182
+ if result_convert_audio.returncode in [1, 2] or not os.path.exists(
183
+ audio_wav
184
+ ):
185
+ raise OperationFailedError(f"Error can't download the audio:\n{errors.decode('utf-8')}")
186
+ logger.info("Process video...")
187
+ result_convert_video = subprocess.Popen(mp4_, **sub_params)
188
+ output, errors = result_convert_video.communicate()
189
+ time.sleep(1)
190
+ if result_convert_video.returncode in [1, 2] or not os.path.exists(
191
+ OutputFile
192
+ ):
193
+ raise OperationFailedError(f"Error can't download the video:\n{errors.decode('utf-8')}")
194
+
195
+
196
+ def old_audio_video_preprocessor(preview, video, OutputFile, audio_wav):
197
+ previous_files_to_remove = [OutputFile, "audio.webm", audio_wav]
198
+ remove_files(previous_files_to_remove)
199
+
200
+ if os.path.exists(video):
201
+ if preview:
202
+ logger.warning(
203
+ "Creating a preview video of 10 seconds, "
204
+ "to disable this option, go to advanced "
205
+ "settings and turn off preview."
206
+ )
207
+ command = f'ffmpeg -y -i "{video}" -ss 00:00:20 -t 00:00:10 -c:v libx264 -c:a aac -strict experimental Video.mp4'
208
+ result_convert_video = subprocess.run(
209
+ command, capture_output=True, text=True, shell=True
210
+ )
211
+ else:
212
+ # Check if the file ends with ".mp4" extension
213
+ if video.endswith(".mp4"):
214
+ destination_path = os.path.join(os.getcwd(), "Video.mp4")
215
+ shutil.copy(video, destination_path)
216
+ result_convert_video = {}
217
+ result_convert_video = subprocess.run(
218
+ "echo Video copied",
219
+ capture_output=True,
220
+ text=True,
221
+ shell=True,
222
+ )
223
+ else:
224
+ logger.warning(
225
+ "File does not have the '.mp4' extension. Converting video."
226
+ )
227
+ command = f'ffmpeg -y -i "{video}" -c:v libx264 -c:a aac -strict experimental Video.mp4'
228
+ result_convert_video = subprocess.run(
229
+ command, capture_output=True, text=True, shell=True
230
+ )
231
+
232
+ if result_convert_video.returncode in [1, 2]:
233
+ raise OperationFailedError("Error can't convert the video")
234
+
235
+ for i in range(120):
236
+ time.sleep(1)
237
+ logger.info("Process video...")
238
+ if os.path.exists(OutputFile):
239
+ time.sleep(1)
240
+ command = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
241
+ result_convert_audio = subprocess.run(
242
+ command, capture_output=True, text=True, shell=True
243
+ )
244
+ time.sleep(1)
245
+ break
246
+ if i == 119:
247
+ # if not os.path.exists(OutputFile):
248
+ raise OperationFailedError("Error processing video")
249
+
250
+ if result_convert_audio.returncode in [1, 2]:
251
+ raise OperationFailedError(
252
+ f"Error can't create the audio file: {result_convert_audio.stderr}"
253
+ )
254
+
255
+ for i in range(120):
256
+ time.sleep(1)
257
+ logger.info("Process audio...")
258
+ if os.path.exists(audio_wav):
259
+ break
260
+ if i == 119:
261
+ raise OperationFailedError("Error can't create the audio file")
262
+
263
+ else:
264
+ video = video.strip()
265
+ if preview:
266
+ logger.warning(
267
+ "Creating a preview from the link, 10 "
268
+ "seconds to disable this option, go to "
269
+ "advanced settings and turn off preview."
270
+ )
271
+ # https://github.com/yt-dlp/yt-dlp/issues/2220
272
+ mp4_ = f'yt-dlp -f "mp4" --downloader ffmpeg --downloader-args "ffmpeg_i: -ss 00:00:20 -t 00:00:10" --force-overwrites --max-downloads 1 --no-warnings --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
273
+ wav_ = "ffmpeg -y -i Video.mp4 -vn -acodec pcm_s16le -ar 44100 -ac 2 audio.wav"
274
+ result_convert_video = subprocess.run(
275
+ mp4_, capture_output=True, text=True, shell=True
276
+ )
277
+ result_convert_audio = subprocess.run(
278
+ wav_, capture_output=True, text=True, shell=True
279
+ )
280
+ if result_convert_audio.returncode in [1, 2]:
281
+ raise OperationFailedError("Error can't download a preview")
282
+ else:
283
+ mp4_ = f'yt-dlp -f "mp4" --force-overwrites --max-downloads 1 --no-warnings --no-abort-on-error --ignore-no-formats-error --restrict-filenames -o {OutputFile} {video}'
284
+ wav_ = f"python -m yt_dlp --output {audio_wav} --force-overwrites --max-downloads 1 --no-warnings --no-abort-on-error --ignore-no-formats-error --extract-audio --audio-format wav {video}"
285
+
286
+ result_convert_audio = subprocess.run(
287
+ wav_, capture_output=True, text=True, shell=True
288
+ )
289
+
290
+ if result_convert_audio.returncode in [1, 2]:
291
+ raise OperationFailedError("Error can't download the audio")
292
+
293
+ for i in range(120):
294
+ time.sleep(1)
295
+ logger.info("Process audio...")
296
+ if os.path.exists(audio_wav) and not os.path.exists(
297
+ "audio.webm"
298
+ ):
299
+ time.sleep(1)
300
+ result_convert_video = subprocess.run(
301
+ mp4_, capture_output=True, text=True, shell=True
302
+ )
303
+ break
304
+ if i == 119:
305
+ raise OperationFailedError("Error downloading the audio")
306
+
307
+ if result_convert_video.returncode in [1, 2]:
308
+ raise OperationFailedError("Error can't download the video")
soni_translate/speech_segmentation.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from whisperx.alignment import (
2
+ DEFAULT_ALIGN_MODELS_TORCH as DAMT,
3
+ DEFAULT_ALIGN_MODELS_HF as DAMHF,
4
+ )
5
+ from whisperx.utils import TO_LANGUAGE_CODE
6
+ import whisperx
7
+ import torch
8
+ import gc
9
+ import os
10
+ import soundfile as sf
11
+ from IPython.utils import capture # noqa
12
+ from .language_configuration import EXTRA_ALIGN, INVERTED_LANGUAGES
13
+ from .logging_setup import logger
14
+ from .postprocessor import sanitize_file_name
15
+ from .utils import remove_directory_contents, run_command
16
+
17
+ # ZERO GPU CONFIG
18
+ import spaces
19
+ import copy
20
+ import random
21
+ import time
22
+
23
+ def random_sleep():
24
+ if os.environ.get("ZERO_GPU") == "TRUE":
25
+ print("Random sleep")
26
+ sleep_time = round(random.uniform(7.2, 9.9), 1)
27
+ time.sleep(sleep_time)
28
+
29
+
30
+ @spaces.GPU(duration=120)
31
+ def load_and_transcribe_audio(asr_model, audio, compute_type, language, asr_options, batch_size, segment_duration_limit):
32
+ # Load model
33
+ model = whisperx.load_model(
34
+ asr_model,
35
+ os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
36
+ compute_type=compute_type,
37
+ language=language,
38
+ asr_options=asr_options,
39
+ )
40
+
41
+ # Transcribe audio
42
+ result = model.transcribe(
43
+ audio,
44
+ batch_size=batch_size,
45
+ chunk_size=segment_duration_limit,
46
+ print_progress=True,
47
+ )
48
+
49
+ del model
50
+ gc.collect()
51
+ torch.cuda.empty_cache() # noqa
52
+
53
+ return result
54
+
55
+ def load_align_and_align_segments(result, audio, DAMHF):
56
+
57
+ # Load alignment model
58
+ model_a, metadata = whisperx.load_align_model(
59
+ language_code=result["language"],
60
+ device=os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
61
+ model_name=None
62
+ if result["language"] in DAMHF.keys()
63
+ else EXTRA_ALIGN[result["language"]],
64
+ )
65
+
66
+ # Align segments
67
+ alignment_result = whisperx.align(
68
+ result["segments"],
69
+ model_a,
70
+ metadata,
71
+ audio,
72
+ os.environ.get("SONITR_DEVICE") if os.environ.get("ZERO_GPU") != "TRUE" else "cuda",
73
+ return_char_alignments=True,
74
+ print_progress=False,
75
+ )
76
+
77
+ # Clean up
78
+ del model_a
79
+ gc.collect()
80
+ torch.cuda.empty_cache() # noqa
81
+
82
+ return alignment_result
83
+
84
+ @spaces.GPU(duration=120)
85
+ def diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers):
86
+
87
+ if os.environ.get("ZERO_GPU") == "TRUE":
88
+ diarize_model.model.to(torch.device("cuda"))
89
+ diarize_segments = diarize_model(
90
+ audio_wav,
91
+ min_speakers=min_speakers,
92
+ max_speakers=max_speakers
93
+ )
94
+ return diarize_segments
95
+
96
+ # ZERO GPU CONFIG
97
+
98
+ ASR_MODEL_OPTIONS = [
99
+ "tiny",
100
+ "base",
101
+ "small",
102
+ "medium",
103
+ "large",
104
+ "large-v1",
105
+ "large-v2",
106
+ "large-v3",
107
+ "distil-large-v2",
108
+ "Systran/faster-distil-whisper-large-v3",
109
+ "tiny.en",
110
+ "base.en",
111
+ "small.en",
112
+ "medium.en",
113
+ "distil-small.en",
114
+ "distil-medium.en",
115
+ "OpenAI_API_Whisper",
116
+ ]
117
+
118
+ COMPUTE_TYPE_GPU = [
119
+ "default",
120
+ "auto",
121
+ "int8",
122
+ "int8_float32",
123
+ "int8_float16",
124
+ "int8_bfloat16",
125
+ "float16",
126
+ "bfloat16",
127
+ "float32"
128
+ ]
129
+
130
+ COMPUTE_TYPE_CPU = [
131
+ "default",
132
+ "auto",
133
+ "int8",
134
+ "int8_float32",
135
+ "int16",
136
+ "float32",
137
+ ]
138
+
139
+ WHISPER_MODELS_PATH = './WHISPER_MODELS'
140
+
141
+
142
+ def openai_api_whisper(
143
+ input_audio_file,
144
+ source_lang=None,
145
+ chunk_duration=1800
146
+ ):
147
+
148
+ info = sf.info(input_audio_file)
149
+ duration = info.duration
150
+
151
+ output_directory = "./whisper_api_audio_parts"
152
+ os.makedirs(output_directory, exist_ok=True)
153
+ remove_directory_contents(output_directory)
154
+
155
+ if duration > chunk_duration:
156
+ # Split the audio file into smaller chunks with 30-minute duration
157
+ cm = f'ffmpeg -i "{input_audio_file}" -f segment -segment_time {chunk_duration} -c:a libvorbis "{output_directory}/output%03d.ogg"'
158
+ run_command(cm)
159
+ # Get list of generated chunk files
160
+ chunk_files = sorted(
161
+ [f"{output_directory}/{f}" for f in os.listdir(output_directory) if f.endswith('.ogg')]
162
+ )
163
+ else:
164
+ one_file = f"{output_directory}/output000.ogg"
165
+ cm = f'ffmpeg -i "{input_audio_file}" -c:a libvorbis {one_file}'
166
+ run_command(cm)
167
+ chunk_files = [one_file]
168
+
169
+ # Transcript
170
+ segments = []
171
+ language = source_lang if source_lang else None
172
+ for i, chunk in enumerate(chunk_files):
173
+ from openai import OpenAI
174
+ client = OpenAI()
175
+
176
+ audio_file = open(chunk, "rb")
177
+ transcription = client.audio.transcriptions.create(
178
+ model="whisper-1",
179
+ file=audio_file,
180
+ language=language,
181
+ response_format="verbose_json",
182
+ timestamp_granularities=["segment"],
183
+ )
184
+
185
+ try:
186
+ transcript_dict = transcription.model_dump()
187
+ except: # noqa
188
+ transcript_dict = transcription.to_dict()
189
+
190
+ if language is None:
191
+ logger.info(f'Language detected: {transcript_dict["language"]}')
192
+ language = TO_LANGUAGE_CODE[transcript_dict["language"]]
193
+
194
+ chunk_time = chunk_duration * (i)
195
+
196
+ for seg in transcript_dict["segments"]:
197
+
198
+ if "start" in seg.keys():
199
+ segments.append(
200
+ {
201
+ "text": seg["text"],
202
+ "start": seg["start"] + chunk_time,
203
+ "end": seg["end"] + chunk_time,
204
+ }
205
+ )
206
+
207
+ audio = whisperx.load_audio(input_audio_file)
208
+ result = {"segments": segments, "language": language}
209
+
210
+ return audio, result
211
+
212
+
213
+ def find_whisper_models():
214
+ path = WHISPER_MODELS_PATH
215
+ folders = []
216
+
217
+ if os.path.exists(path):
218
+ for folder in os.listdir(path):
219
+ folder_path = os.path.join(path, folder)
220
+ if (
221
+ os.path.isdir(folder_path)
222
+ and 'model.bin' in os.listdir(folder_path)
223
+ ):
224
+ folders.append(folder)
225
+ return folders
226
+
227
+ def transcribe_speech(
228
+ audio_wav,
229
+ asr_model,
230
+ compute_type,
231
+ batch_size,
232
+ SOURCE_LANGUAGE,
233
+ literalize_numbers=True,
234
+ segment_duration_limit=15,
235
+ ):
236
+ """
237
+ Transcribe speech using a whisper model.
238
+
239
+ Parameters:
240
+ - audio_wav (str): Path to the audio file in WAV format.
241
+ - asr_model (str): The whisper model to be loaded.
242
+ - compute_type (str): Type of compute to be used (e.g., 'int8', 'float16').
243
+ - batch_size (int): Batch size for transcription.
244
+ - SOURCE_LANGUAGE (str): Source language for transcription.
245
+
246
+ Returns:
247
+ - Tuple containing:
248
+ - audio: Loaded audio file.
249
+ - result: Transcription result as a dictionary.
250
+ """
251
+
252
+ if asr_model == "OpenAI_API_Whisper":
253
+ if literalize_numbers:
254
+ logger.info(
255
+ "OpenAI's API Whisper does not support "
256
+ "the literalization of numbers."
257
+ )
258
+ return openai_api_whisper(audio_wav, SOURCE_LANGUAGE)
259
+
260
+ # https://github.com/openai/whisper/discussions/277
261
+ prompt = "以下是普通话的句子。" if SOURCE_LANGUAGE == "zh" else None
262
+ SOURCE_LANGUAGE = (
263
+ SOURCE_LANGUAGE if SOURCE_LANGUAGE != "zh-TW" else "zh"
264
+ )
265
+ asr_options = {
266
+ "initial_prompt": prompt,
267
+ "suppress_numerals": literalize_numbers
268
+ }
269
+
270
+ if asr_model not in ASR_MODEL_OPTIONS:
271
+
272
+ base_dir = WHISPER_MODELS_PATH
273
+ if not os.path.exists(base_dir):
274
+ os.makedirs(base_dir)
275
+ model_dir = os.path.join(base_dir, sanitize_file_name(asr_model))
276
+
277
+ if not os.path.exists(model_dir):
278
+ from ctranslate2.converters import TransformersConverter
279
+
280
+ quantization = "float32"
281
+ # Download new model
282
+ try:
283
+ converter = TransformersConverter(
284
+ asr_model,
285
+ low_cpu_mem_usage=True,
286
+ copy_files=[
287
+ "tokenizer_config.json", "preprocessor_config.json"
288
+ ]
289
+ )
290
+ converter.convert(
291
+ model_dir,
292
+ quantization=quantization,
293
+ force=False
294
+ )
295
+ except Exception as error:
296
+ if "File tokenizer_config.json does not exist" in str(error):
297
+ converter._copy_files = [
298
+ "tokenizer.json", "preprocessor_config.json"
299
+ ]
300
+ converter.convert(
301
+ model_dir,
302
+ quantization=quantization,
303
+ force=True
304
+ )
305
+ else:
306
+ raise error
307
+
308
+ asr_model = model_dir
309
+ logger.info(f"ASR Model: {str(model_dir)}")
310
+
311
+ audio = whisperx.load_audio(audio_wav)
312
+
313
+ result = load_and_transcribe_audio(
314
+ asr_model, audio, compute_type, SOURCE_LANGUAGE, asr_options, batch_size, segment_duration_limit
315
+ )
316
+
317
+ if result["language"] == "zh" and not prompt:
318
+ result["language"] = "zh-TW"
319
+ logger.info("Chinese - Traditional (zh-TW)")
320
+
321
+
322
+ return audio, result
323
+
324
+
325
+ def align_speech(audio, result):
326
+ """
327
+ Aligns speech segments based on the provided audio and result metadata.
328
+
329
+ Parameters:
330
+ - audio (array): The audio data in a suitable format for alignment.
331
+ - result (dict): Metadata containing information about the segments
332
+ and language.
333
+
334
+ Returns:
335
+ - result (dict): Updated metadata after aligning the segments with
336
+ the audio. This includes character-level alignments if
337
+ 'return_char_alignments' is set to True.
338
+
339
+ Notes:
340
+ - This function uses language-specific models to align speech segments.
341
+ - It performs language compatibility checks and selects the
342
+ appropriate alignment model.
343
+ - Cleans up memory by releasing resources after alignment.
344
+ """
345
+ DAMHF.update(DAMT) # lang align
346
+ if (
347
+ not result["language"] in DAMHF.keys()
348
+ and not result["language"] in EXTRA_ALIGN.keys()
349
+ ):
350
+ logger.warning(
351
+ "Automatic detection: Source language not compatible with align"
352
+ )
353
+ raise ValueError(
354
+ f"Detected language {result['language']} incompatible, "
355
+ "you can select the source language to avoid this error."
356
+ )
357
+ if (
358
+ result["language"] in EXTRA_ALIGN.keys()
359
+ and EXTRA_ALIGN[result["language"]] == ""
360
+ ):
361
+ lang_name = (
362
+ INVERTED_LANGUAGES[result["language"]]
363
+ if result["language"] in INVERTED_LANGUAGES.keys()
364
+ else result["language"]
365
+ )
366
+ logger.warning(
367
+ "No compatible wav2vec2 model found "
368
+ f"for the language '{lang_name}', skipping alignment."
369
+ )
370
+ return result
371
+
372
+ random_sleep()
373
+ result = load_align_and_align_segments(result, audio, DAMHF)
374
+
375
+ return result
376
+
377
+
378
+ diarization_models = {
379
+ "pyannote_3.1": "pyannote/speaker-diarization-3.1",
380
+ "pyannote_2.1": "pyannote/speaker-diarization@2.1",
381
+ "disable": "",
382
+ }
383
+
384
+
385
+ def reencode_speakers(result):
386
+
387
+ if result["segments"][0]["speaker"] == "SPEAKER_00":
388
+ return result
389
+
390
+ speaker_mapping = {}
391
+ counter = 0
392
+
393
+ logger.debug("Reencode speakers")
394
+
395
+ for segment in result["segments"]:
396
+ old_speaker = segment["speaker"]
397
+ if old_speaker not in speaker_mapping:
398
+ speaker_mapping[old_speaker] = f"SPEAKER_{counter:02d}"
399
+ counter += 1
400
+ segment["speaker"] = speaker_mapping[old_speaker]
401
+
402
+ return result
403
+
404
+
405
+ def diarize_speech(
406
+ audio_wav,
407
+ result,
408
+ min_speakers,
409
+ max_speakers,
410
+ YOUR_HF_TOKEN,
411
+ model_name="pyannote/speaker-diarization@2.1",
412
+ ):
413
+ """
414
+ Performs speaker diarization on speech segments.
415
+
416
+ Parameters:
417
+ - audio_wav (array): Audio data in WAV format to perform speaker
418
+ diarization.
419
+ - result (dict): Metadata containing information about speech segments
420
+ and alignments.
421
+ - min_speakers (int): Minimum number of speakers expected in the audio.
422
+ - max_speakers (int): Maximum number of speakers expected in the audio.
423
+ - YOUR_HF_TOKEN (str): Your Hugging Face API token for model
424
+ authentication.
425
+ - model_name (str): Name of the speaker diarization model to be used
426
+ (default: "pyannote/speaker-diarization@2.1").
427
+
428
+ Returns:
429
+ - result_diarize (dict): Updated metadata after assigning speaker
430
+ labels to segments.
431
+
432
+ Notes:
433
+ - This function utilizes a speaker diarization model to label speaker
434
+ segments in the audio.
435
+ - It assigns speakers to word-level segments based on diarization results.
436
+ - Cleans up memory by releasing resources after diarization.
437
+ - If only one speaker is specified, each segment is automatically assigned
438
+ as the first speaker, eliminating the need for diarization inference.
439
+ """
440
+
441
+ if max(min_speakers, max_speakers) > 1 and model_name:
442
+ try:
443
+
444
+ diarize_model = whisperx.DiarizationPipeline(
445
+ model_name=model_name,
446
+ use_auth_token=YOUR_HF_TOKEN,
447
+ device=os.environ.get("SONITR_DEVICE"),
448
+ )
449
+
450
+ except Exception as error:
451
+ error_str = str(error)
452
+ gc.collect()
453
+ torch.cuda.empty_cache() # noqa
454
+ if "'NoneType' object has no attribute 'to'" in error_str:
455
+ if model_name == diarization_models["pyannote_2.1"]:
456
+ raise ValueError(
457
+ "Accept the license agreement for using Pyannote 2.1."
458
+ " You need to have an account on Hugging Face and "
459
+ "accept the license to use the models: "
460
+ "https://huggingface.co/pyannote/speaker-diarization "
461
+ "and https://huggingface.co/pyannote/segmentation "
462
+ "Get your KEY TOKEN here: "
463
+ "https://hf.co/settings/tokens "
464
+ )
465
+ elif model_name == diarization_models["pyannote_3.1"]:
466
+ raise ValueError(
467
+ "New Licence Pyannote 3.1: You need to have an account"
468
+ " on Hugging Face and accept the license to use the "
469
+ "models: https://huggingface.co/pyannote/speaker-diarization-3.1 " # noqa
470
+ "and https://huggingface.co/pyannote/segmentation-3.0 "
471
+ )
472
+ else:
473
+ raise error
474
+
475
+ random_sleep()
476
+ diarize_segments = diarize_audio(diarize_model, audio_wav, min_speakers, max_speakers)
477
+
478
+ result_diarize = whisperx.assign_word_speakers(
479
+ diarize_segments, result
480
+ )
481
+
482
+ for segment in result_diarize["segments"]:
483
+ if "speaker" not in segment:
484
+ segment["speaker"] = "SPEAKER_00"
485
+ logger.warning(
486
+ f"No speaker detected in {segment['start']}. First TTS "
487
+ f"will be used for the segment text: {segment['text']} "
488
+ )
489
+
490
+ del diarize_model
491
+ gc.collect()
492
+ torch.cuda.empty_cache() # noqa
493
+ else:
494
+ result_diarize = result
495
+ result_diarize["segments"] = [
496
+ {**item, "speaker": "SPEAKER_00"}
497
+ for item in result_diarize["segments"]
498
+ ]
499
+ return reencode_speakers(result_diarize)
soni_translate/text_multiformat_processor.py ADDED
@@ -0,0 +1,987 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .logging_setup import logger
2
+ from whisperx.utils import get_writer
3
+ from .utils import remove_files, run_command, remove_directory_contents
4
+ from typing import List
5
+ import srt
6
+ import re
7
+ import os
8
+ import copy
9
+ import string
10
+ import soundfile as sf
11
+ from PIL import Image, ImageOps, ImageDraw, ImageFont
12
+
13
+ punctuation_list = list(
14
+ string.punctuation + "¡¿«»„”“”‚‘’「」『』《》()【】〈〉〔〕〖〗〘〙〚〛⸤⸥⸨⸩"
15
+ )
16
+ symbol_list = punctuation_list + ["", "..", "..."]
17
+
18
+
19
+ def extract_from_srt(file_path):
20
+ with open(file_path, "r", encoding="utf-8") as file:
21
+ srt_content = file.read()
22
+
23
+ subtitle_generator = srt.parse(srt_content)
24
+ srt_content_list = list(subtitle_generator)
25
+
26
+ return srt_content_list
27
+
28
+
29
+ def clean_text(text):
30
+
31
+ # Remove content within square brackets
32
+ text = re.sub(r'\[.*?\]', '', text)
33
+ # Add pattern to remove content within <comment> tags
34
+ text = re.sub(r'<comment>.*?</comment>', '', text)
35
+ # Remove HTML tags
36
+ text = re.sub(r'<.*?>', '', text)
37
+ # Remove "♫" and "♪" content
38
+ text = re.sub(r'♫.*?♫', '', text)
39
+ text = re.sub(r'♪.*?♪', '', text)
40
+ # Replace newline characters with an empty string
41
+ text = text.replace("\n", ". ")
42
+ # Remove double quotation marks
43
+ text = text.replace('"', '')
44
+ # Collapse multiple spaces and replace with a single space
45
+ text = re.sub(r"\s+", " ", text)
46
+ # Normalize spaces around periods
47
+ text = re.sub(r"[\s\.]+(?=\s)", ". ", text)
48
+ # Check if there are ♫ or ♪ symbols present
49
+ if '♫' in text or '♪' in text:
50
+ return ""
51
+
52
+ text = text.strip()
53
+
54
+ # Valid text
55
+ return text if text not in symbol_list else ""
56
+
57
+
58
+ def srt_file_to_segments(file_path, speaker=False):
59
+ try:
60
+ srt_content_list = extract_from_srt(file_path)
61
+ except Exception as error:
62
+ logger.error(str(error))
63
+ fixed_file = "fixed_sub.srt"
64
+ remove_files(fixed_file)
65
+ fix_sub = f'ffmpeg -i "{file_path}" "{fixed_file}" -y'
66
+ run_command(fix_sub)
67
+ srt_content_list = extract_from_srt(fixed_file)
68
+
69
+ segments = []
70
+ for segment in srt_content_list:
71
+
72
+ text = clean_text(str(segment.content))
73
+
74
+ if text:
75
+ segments.append(
76
+ {
77
+ "text": text,
78
+ "start": float(segment.start.total_seconds()),
79
+ "end": float(segment.end.total_seconds()),
80
+ }
81
+ )
82
+
83
+ if not segments:
84
+ raise Exception("No data found in srt subtitle file")
85
+
86
+ if speaker:
87
+ segments = [{**seg, "speaker": "SPEAKER_00"} for seg in segments]
88
+
89
+ return {"segments": segments}
90
+
91
+
92
+ # documents
93
+
94
+
95
+ def dehyphenate(lines: List[str], line_no: int) -> List[str]:
96
+ next_line = lines[line_no + 1]
97
+ word_suffix = next_line.split(" ")[0]
98
+
99
+ lines[line_no] = lines[line_no][:-1] + word_suffix
100
+ lines[line_no + 1] = lines[line_no + 1][len(word_suffix):]
101
+ return lines
102
+
103
+
104
+ def remove_hyphens(text: str) -> str:
105
+ """
106
+
107
+ This fails for:
108
+ * Natural dashes: well-known, self-replication, use-cases, non-semantic,
109
+ Post-processing, Window-wise, viewpoint-dependent
110
+ * Trailing math operands: 2 - 4
111
+ * Names: Lopez-Ferreras, VGG-19, CIFAR-100
112
+ """
113
+ lines = [line.rstrip() for line in text.split("\n")]
114
+
115
+ # Find dashes
116
+ line_numbers = []
117
+ for line_no, line in enumerate(lines[:-1]):
118
+ if line.endswith("-"):
119
+ line_numbers.append(line_no)
120
+
121
+ # Replace
122
+ for line_no in line_numbers:
123
+ lines = dehyphenate(lines, line_no)
124
+
125
+ return "\n".join(lines)
126
+
127
+
128
+ def pdf_to_txt(pdf_file, start_page, end_page):
129
+ from pypdf import PdfReader
130
+
131
+ with open(pdf_file, "rb") as file:
132
+ reader = PdfReader(file)
133
+ logger.debug(f"Total pages: {reader.get_num_pages()}")
134
+ text = ""
135
+
136
+ start_page_idx = max((start_page-1), 0)
137
+ end_page_inx = min((end_page), (reader.get_num_pages()))
138
+ document_pages = reader.pages[start_page_idx:end_page_inx]
139
+ logger.info(
140
+ f"Selected pages from {start_page_idx} to {end_page_inx}: "
141
+ f"{len(document_pages)}"
142
+ )
143
+
144
+ for page in document_pages:
145
+ text += remove_hyphens(page.extract_text())
146
+ return text
147
+
148
+
149
+ def docx_to_txt(docx_file):
150
+ # https://github.com/AlJohri/docx2pdf update
151
+ from docx import Document
152
+
153
+ doc = Document(docx_file)
154
+ text = ""
155
+ for paragraph in doc.paragraphs:
156
+ text += paragraph.text + "\n"
157
+ return text
158
+
159
+
160
+ def replace_multiple_elements(text, replacements):
161
+ pattern = re.compile("|".join(map(re.escape, replacements.keys())))
162
+ replaced_text = pattern.sub(
163
+ lambda match: replacements[match.group(0)], text
164
+ )
165
+
166
+ # Remove multiple spaces
167
+ replaced_text = re.sub(r"\s+", " ", replaced_text)
168
+
169
+ return replaced_text
170
+
171
+
172
+ def document_preprocessor(file_path, is_string, start_page, end_page):
173
+ if not is_string:
174
+ file_ext = os.path.splitext(file_path)[1].lower()
175
+
176
+ if is_string:
177
+ text = file_path
178
+ elif file_ext == ".pdf":
179
+ text = pdf_to_txt(file_path, start_page, end_page)
180
+ elif file_ext == ".docx":
181
+ text = docx_to_txt(file_path)
182
+ elif file_ext == ".txt":
183
+ with open(
184
+ file_path, "r", encoding='utf-8', errors='replace'
185
+ ) as file:
186
+ text = file.read()
187
+ else:
188
+ raise Exception("Unsupported file format")
189
+
190
+ # Add space to break segments more easily later
191
+ replacements = {
192
+ "、": "、 ",
193
+ "。": "。 ",
194
+ # "\n": " ",
195
+ }
196
+ text = replace_multiple_elements(text, replacements)
197
+
198
+ # Save text to a .txt file
199
+ # file_name = os.path.splitext(os.path.basename(file_path))[0]
200
+ txt_file_path = "./text_preprocessor.txt"
201
+
202
+ with open(
203
+ txt_file_path, "w", encoding='utf-8', errors='replace'
204
+ ) as txt_file:
205
+ txt_file.write(text)
206
+
207
+ return txt_file_path, text
208
+
209
+
210
+ def split_text_into_chunks(text, chunk_size):
211
+ words = re.findall(r"\b\w+\b", text)
212
+ chunks = []
213
+ current_chunk = ""
214
+ for word in words:
215
+ if (
216
+ len(current_chunk) + len(word) + 1 <= chunk_size
217
+ ): # Adding 1 for the space between words
218
+ if current_chunk:
219
+ current_chunk += " "
220
+ current_chunk += word
221
+ else:
222
+ chunks.append(current_chunk)
223
+ current_chunk = word
224
+ if current_chunk:
225
+ chunks.append(current_chunk)
226
+ return chunks
227
+
228
+
229
+ def determine_chunk_size(file_name):
230
+ patterns = {
231
+ re.compile(r".*-(Male|Female)$"): 1024, # by character
232
+ re.compile(r".* BARK$"): 100, # t 64 256
233
+ re.compile(r".* VITS$"): 500,
234
+ re.compile(
235
+ r".+\.(wav|mp3|ogg|m4a)$"
236
+ ): 150, # t 250 400 api automatic split
237
+ re.compile(r".* VITS-onnx$"): 250, # automatic sentence split
238
+ re.compile(r".* OpenAI-TTS$"): 1024 # max charaters 4096
239
+ }
240
+
241
+ for pattern, chunk_size in patterns.items():
242
+ if pattern.match(file_name):
243
+ return chunk_size
244
+
245
+ # Default chunk size if the file doesn't match any pattern; max 1800
246
+ return 100
247
+
248
+
249
+ def plain_text_to_segments(result_text=None, chunk_size=None):
250
+ if not chunk_size:
251
+ chunk_size = 100
252
+ text_chunks = split_text_into_chunks(result_text, chunk_size)
253
+
254
+ segments_chunks = []
255
+ for num, chunk in enumerate(text_chunks):
256
+ chunk_dict = {
257
+ "text": chunk,
258
+ "start": (1.0 + num),
259
+ "end": (2.0 + num),
260
+ "speaker": "SPEAKER_00",
261
+ }
262
+ segments_chunks.append(chunk_dict)
263
+
264
+ result_diarize = {"segments": segments_chunks}
265
+
266
+ return result_diarize
267
+
268
+
269
+ def segments_to_plain_text(result_diarize):
270
+ complete_text = ""
271
+ for seg in result_diarize["segments"]:
272
+ complete_text += seg["text"] + " " # issue
273
+
274
+ # Save text to a .txt file
275
+ # file_name = os.path.splitext(os.path.basename(file_path))[0]
276
+ txt_file_path = "./text_translation.txt"
277
+
278
+ with open(
279
+ txt_file_path, "w", encoding='utf-8', errors='replace'
280
+ ) as txt_file:
281
+ txt_file.write(complete_text)
282
+
283
+ return txt_file_path, complete_text
284
+
285
+
286
+ # doc to video
287
+
288
+ COLORS = {
289
+ "black": (0, 0, 0),
290
+ "white": (255, 255, 255),
291
+ "red": (255, 0, 0),
292
+ "green": (0, 255, 0),
293
+ "blue": (0, 0, 255),
294
+ "yellow": (255, 255, 0),
295
+ "light_gray": (200, 200, 200),
296
+ "light_blue": (173, 216, 230),
297
+ "light_green": (144, 238, 144),
298
+ "light_yellow": (255, 255, 224),
299
+ "light_pink": (255, 182, 193),
300
+ "lavender": (230, 230, 250),
301
+ "peach": (255, 218, 185),
302
+ "light_cyan": (224, 255, 255),
303
+ "light_salmon": (255, 160, 122),
304
+ "light_green_yellow": (173, 255, 47),
305
+ }
306
+
307
+ BORDER_COLORS = ["dynamic"] + list(COLORS.keys())
308
+
309
+
310
+ def calculate_average_color(img):
311
+ # Resize the image to a small size for faster processing
312
+ img_small = img.resize((50, 50))
313
+ # Calculate the average color
314
+ average_color = img_small.convert("RGB").resize((1, 1)).getpixel((0, 0))
315
+ return average_color
316
+
317
+
318
+ def add_border_to_image(
319
+ image_path,
320
+ target_width,
321
+ target_height,
322
+ border_color=None
323
+ ):
324
+
325
+ img = Image.open(image_path)
326
+
327
+ # Calculate the width and height for the new image with borders
328
+ original_width, original_height = img.size
329
+ original_aspect_ratio = original_width / original_height
330
+ target_aspect_ratio = target_width / target_height
331
+
332
+ # Resize the image to fit the target resolution retaining aspect ratio
333
+ if original_aspect_ratio > target_aspect_ratio:
334
+ # Image is wider, calculate new height
335
+ new_height = int(target_width / original_aspect_ratio)
336
+ resized_img = img.resize((target_width, new_height))
337
+ else:
338
+ # Image is taller, calculate new width
339
+ new_width = int(target_height * original_aspect_ratio)
340
+ resized_img = img.resize((new_width, target_height))
341
+
342
+ # Calculate padding for borders
343
+ padding = (0, 0, 0, 0)
344
+ if resized_img.size[0] != target_width or resized_img.size[1] != target_height:
345
+ if original_aspect_ratio > target_aspect_ratio:
346
+ # Add borders vertically
347
+ padding = (0, (target_height - resized_img.size[1]) // 2, 0, (target_height - resized_img.size[1]) // 2)
348
+ else:
349
+ # Add borders horizontally
350
+ padding = ((target_width - resized_img.size[0]) // 2, 0, (target_width - resized_img.size[0]) // 2, 0)
351
+
352
+ # Add borders with specified color
353
+ if not border_color or border_color == "dynamic":
354
+ border_color = calculate_average_color(resized_img)
355
+ else:
356
+ border_color = COLORS.get(border_color, (0, 0, 0))
357
+
358
+ bordered_img = ImageOps.expand(resized_img, padding, fill=border_color)
359
+
360
+ bordered_img.save(image_path)
361
+
362
+ return image_path
363
+
364
+
365
+ def resize_and_position_subimage(
366
+ subimage,
367
+ max_width,
368
+ max_height,
369
+ subimage_position,
370
+ main_width,
371
+ main_height
372
+ ):
373
+ subimage_width, subimage_height = subimage.size
374
+
375
+ # Resize subimage if it exceeds maximum dimensions
376
+ if subimage_width > max_width or subimage_height > max_height:
377
+ # Calculate scaling factor
378
+ width_scale = max_width / subimage_width
379
+ height_scale = max_height / subimage_height
380
+ scale = min(width_scale, height_scale)
381
+
382
+ # Resize subimage
383
+ subimage = subimage.resize(
384
+ (int(subimage_width * scale), int(subimage_height * scale))
385
+ )
386
+
387
+ # Calculate position to place the subimage
388
+ if subimage_position == "top-left":
389
+ subimage_x = 0
390
+ subimage_y = 0
391
+ elif subimage_position == "top-right":
392
+ subimage_x = main_width - subimage.width
393
+ subimage_y = 0
394
+ elif subimage_position == "bottom-left":
395
+ subimage_x = 0
396
+ subimage_y = main_height - subimage.height
397
+ elif subimage_position == "bottom-right":
398
+ subimage_x = main_width - subimage.width
399
+ subimage_y = main_height - subimage.height
400
+ else:
401
+ raise ValueError(
402
+ "Invalid subimage_position. Choose from 'top-left', 'top-right',"
403
+ " 'bottom-left', or 'bottom-right'."
404
+ )
405
+
406
+ return subimage, subimage_x, subimage_y
407
+
408
+
409
+ def create_image_with_text_and_subimages(
410
+ text,
411
+ subimages,
412
+ width,
413
+ height,
414
+ text_color,
415
+ background_color,
416
+ output_file
417
+ ):
418
+ # Create an image with the specified resolution and background color
419
+ image = Image.new('RGB', (width, height), color=background_color)
420
+
421
+ # Initialize ImageDraw object
422
+ draw = ImageDraw.Draw(image)
423
+
424
+ # Load a font
425
+ font = ImageFont.load_default() # You can specify your font file here
426
+
427
+ # Calculate text size and position
428
+ text_bbox = draw.textbbox((0, 0), text, font=font)
429
+ text_width = text_bbox[2] - text_bbox[0]
430
+ text_height = text_bbox[3] - text_bbox[1]
431
+ text_x = (width - text_width) / 2
432
+ text_y = (height - text_height) / 2
433
+
434
+ # Draw text on the image
435
+ draw.text((text_x, text_y), text, fill=text_color, font=font)
436
+
437
+ # Paste subimages onto the main image
438
+ for subimage_path, subimage_position in subimages:
439
+ # Open the subimage
440
+ subimage = Image.open(subimage_path)
441
+
442
+ # Convert subimage to RGBA mode if it doesn't have an alpha channel
443
+ if subimage.mode != 'RGBA':
444
+ subimage = subimage.convert('RGBA')
445
+
446
+ # Resize and position the subimage
447
+ subimage, subimage_x, subimage_y = resize_and_position_subimage(
448
+ subimage, width / 4, height / 4, subimage_position, width, height
449
+ )
450
+
451
+ # Paste the subimage onto the main image
452
+ image.paste(subimage, (int(subimage_x), int(subimage_y)), subimage)
453
+
454
+ image.save(output_file)
455
+
456
+ return output_file
457
+
458
+
459
+ def doc_to_txtximg_pages(
460
+ document,
461
+ width,
462
+ height,
463
+ start_page,
464
+ end_page,
465
+ bcolor
466
+ ):
467
+ from pypdf import PdfReader
468
+
469
+ images_folder = "pdf_images/"
470
+ os.makedirs(images_folder, exist_ok=True)
471
+ remove_directory_contents(images_folder)
472
+
473
+ # First image
474
+ text_image = os.path.basename(document)[:-4]
475
+ subimages = [("./assets/logo.jpeg", "top-left")]
476
+ text_color = (255, 255, 255) if bcolor == "black" else (0, 0, 0) # w|b
477
+ background_color = COLORS.get(bcolor, (255, 255, 255)) # dynamic white
478
+ first_image = "pdf_images/0000_00_aaa.png"
479
+
480
+ create_image_with_text_and_subimages(
481
+ text_image,
482
+ subimages,
483
+ width,
484
+ height,
485
+ text_color,
486
+ background_color,
487
+ first_image
488
+ )
489
+
490
+ reader = PdfReader(document)
491
+ logger.debug(f"Total pages: {reader.get_num_pages()}")
492
+
493
+ start_page_idx = max((start_page-1), 0)
494
+ end_page_inx = min((end_page), (reader.get_num_pages()))
495
+ document_pages = reader.pages[start_page_idx:end_page_inx]
496
+
497
+ logger.info(
498
+ f"Selected pages from {start_page_idx} to {end_page_inx}: "
499
+ f"{len(document_pages)}"
500
+ )
501
+
502
+ data_doc = {}
503
+ for i, page in enumerate(document_pages):
504
+
505
+ count = 0
506
+ images = []
507
+ for image_file_object in page.images:
508
+ img_name = f"{images_folder}{i:04d}_{count:02d}_{image_file_object.name}"
509
+ images.append(img_name)
510
+ with open(img_name, "wb") as fp:
511
+ fp.write(image_file_object.data)
512
+ count += 1
513
+ img_name = add_border_to_image(img_name, width, height, bcolor)
514
+
515
+ data_doc[i] = {
516
+ "text": remove_hyphens(page.extract_text()),
517
+ "images": images
518
+ }
519
+
520
+ return data_doc
521
+
522
+
523
+ def page_data_to_segments(result_text=None, chunk_size=None):
524
+
525
+ if not chunk_size:
526
+ chunk_size = 100
527
+
528
+ segments_chunks = []
529
+ time_global = 0
530
+ for page, result_data in result_text.items():
531
+ # result_image = result_data["images"]
532
+ result_text = result_data["text"]
533
+ text_chunks = split_text_into_chunks(result_text, chunk_size)
534
+ if not text_chunks:
535
+ text_chunks = [" "]
536
+
537
+ for chunk in text_chunks:
538
+ chunk_dict = {
539
+ "text": chunk,
540
+ "start": (1.0 + time_global),
541
+ "end": (2.0 + time_global),
542
+ "speaker": "SPEAKER_00",
543
+ "page": page,
544
+ }
545
+ segments_chunks.append(chunk_dict)
546
+ time_global += 1
547
+
548
+ result_diarize = {"segments": segments_chunks}
549
+
550
+ return result_diarize
551
+
552
+
553
+ def update_page_data(result_diarize, doc_data):
554
+ complete_text = ""
555
+ current_page = result_diarize["segments"][0]["page"]
556
+ text_page = ""
557
+
558
+ for seg in result_diarize["segments"]:
559
+ text = seg["text"] + " " # issue
560
+ complete_text += text
561
+
562
+ page = seg["page"]
563
+
564
+ if page == current_page:
565
+ text_page += text
566
+ else:
567
+ doc_data[current_page]["text"] = text_page
568
+
569
+ # Next
570
+ text_page = text
571
+ current_page = page
572
+
573
+ if doc_data[current_page]["text"] != text_page:
574
+ doc_data[current_page]["text"] = text_page
575
+
576
+ return doc_data
577
+
578
+
579
+ def fix_timestamps_docs(result_diarize, audio_files):
580
+ current_start = 0.0
581
+
582
+ for seg, audio in zip(result_diarize["segments"], audio_files):
583
+ duration = round(sf.info(audio).duration, 2)
584
+
585
+ seg["start"] = current_start
586
+ current_start += duration
587
+ seg["end"] = current_start
588
+
589
+ return result_diarize
590
+
591
+
592
+ def create_video_from_images(
593
+ doc_data,
594
+ result_diarize
595
+ ):
596
+
597
+ # First image path
598
+ first_image = "pdf_images/0000_00_aaa.png"
599
+
600
+ # Time segments and images
601
+ max_pages_idx = len(doc_data) - 1
602
+ current_page = result_diarize["segments"][0]["page"]
603
+ duration_page = 0.0
604
+ last_image = None
605
+
606
+ for seg in result_diarize["segments"]:
607
+ start = seg["start"]
608
+ end = seg["end"]
609
+ duration_seg = end - start
610
+
611
+ page = seg["page"]
612
+
613
+ if page == current_page:
614
+ duration_page += duration_seg
615
+ else:
616
+
617
+ images = doc_data[current_page]["images"]
618
+
619
+ if first_image:
620
+ images = [first_image] + images
621
+ first_image = None
622
+ if not doc_data[min(max_pages_idx, (current_page+1))]["text"].strip():
623
+ images = images + doc_data[min(max_pages_idx, (current_page+1))]["images"]
624
+ if not images and last_image:
625
+ images = [last_image]
626
+
627
+ # Calculate images duration
628
+ time_duration_per_image = round((duration_page / len(images)), 2)
629
+ doc_data[current_page]["time_per_image"] = time_duration_per_image
630
+
631
+ # Next values
632
+ doc_data[current_page]["images"] = images
633
+ last_image = images[-1]
634
+ duration_page = duration_seg
635
+ current_page = page
636
+
637
+ if "time_per_image" not in doc_data[current_page].keys():
638
+ images = doc_data[current_page]["images"]
639
+ if first_image:
640
+ images = [first_image] + images
641
+ if not images:
642
+ images = [last_image]
643
+ time_duration_per_image = round((duration_page / len(images)), 2)
644
+ doc_data[current_page]["time_per_image"] = time_duration_per_image
645
+
646
+ # Timestamped image video.
647
+ with open("list.txt", "w") as file:
648
+
649
+ for i, page in enumerate(doc_data.values()):
650
+
651
+ duration = page["time_per_image"]
652
+ for img in page["images"]:
653
+ if i == len(doc_data) - 1 and img == page["images"][-1]: # Check if it's the last item
654
+ file.write(f"file {img}\n")
655
+ file.write(f"outpoint {duration}")
656
+ else:
657
+ file.write(f"file {img}\n")
658
+ file.write(f"outpoint {duration}\n")
659
+
660
+ out_video = "video_from_images.mp4"
661
+ remove_files(out_video)
662
+
663
+ cm = f"ffmpeg -y -f concat -i list.txt -c:v libx264 -preset veryfast -crf 18 -pix_fmt yuv420p {out_video}"
664
+ cm_alt = f"ffmpeg -f concat -i list.txt -c:v libx264 -r 30 -pix_fmt yuv420p -y {out_video}"
665
+ try:
666
+ run_command(cm)
667
+ except Exception as error:
668
+ logger.error(str(error))
669
+ remove_files(out_video)
670
+ run_command(cm_alt)
671
+
672
+ return out_video
673
+
674
+
675
+ def merge_video_and_audio(video_doc, final_wav_file):
676
+
677
+ fixed_audio = "fixed_audio.mp3"
678
+ remove_files(fixed_audio)
679
+ cm = f"ffmpeg -i {final_wav_file} -c:a libmp3lame {fixed_audio}"
680
+ run_command(cm)
681
+
682
+ vid_out = "video_book.mp4"
683
+ remove_files(vid_out)
684
+ cm = f"ffmpeg -i {video_doc} -i {fixed_audio} -c:v copy -c:a copy -map 0:v -map 1:a -shortest {vid_out}"
685
+ run_command(cm)
686
+
687
+ return vid_out
688
+
689
+
690
+ # subtitles
691
+
692
+
693
+ def get_subtitle(
694
+ language,
695
+ segments_data,
696
+ extension,
697
+ filename=None,
698
+ highlight_words=False,
699
+ ):
700
+ if not filename:
701
+ filename = "task_subtitle"
702
+
703
+ is_ass_extension = False
704
+ if extension == "ass":
705
+ is_ass_extension = True
706
+ extension = "srt"
707
+
708
+ sub_file = filename + "." + extension
709
+ support_name = filename + ".mp3"
710
+ remove_files(sub_file)
711
+
712
+ writer = get_writer(extension, output_dir=".")
713
+ word_options = {
714
+ "highlight_words": highlight_words,
715
+ "max_line_count": None,
716
+ "max_line_width": None,
717
+ }
718
+
719
+ # Get data subs
720
+ subtitle_data = copy.deepcopy(segments_data)
721
+ subtitle_data["language"] = (
722
+ "ja" if language in ["ja", "zh", "zh-TW"] else language
723
+ )
724
+
725
+ # Clean
726
+ if not highlight_words:
727
+ subtitle_data.pop("word_segments", None)
728
+ for segment in subtitle_data["segments"]:
729
+ for key in ["speaker", "chars", "words"]:
730
+ segment.pop(key, None)
731
+
732
+ writer(
733
+ subtitle_data,
734
+ support_name,
735
+ word_options,
736
+ )
737
+
738
+ if is_ass_extension:
739
+ temp_name = filename + ".ass"
740
+ remove_files(temp_name)
741
+ convert_sub = f'ffmpeg -i "{sub_file}" "{temp_name}" -y'
742
+ run_command(convert_sub)
743
+ sub_file = temp_name
744
+
745
+ return sub_file
746
+
747
+
748
+ def process_subtitles(
749
+ deep_copied_result,
750
+ align_language,
751
+ result_diarize,
752
+ output_format_subtitle,
753
+ TRANSLATE_AUDIO_TO,
754
+ ):
755
+ name_ori = "sub_ori."
756
+ name_tra = "sub_tra."
757
+ remove_files(
758
+ [name_ori + output_format_subtitle, name_tra + output_format_subtitle]
759
+ )
760
+
761
+ writer = get_writer(output_format_subtitle, output_dir=".")
762
+ word_options = {
763
+ "highlight_words": False,
764
+ "max_line_count": None,
765
+ "max_line_width": None,
766
+ }
767
+
768
+ # original lang
769
+ subs_copy_result = copy.deepcopy(deep_copied_result)
770
+ subs_copy_result["language"] = (
771
+ "zh" if align_language == "zh-TW" else align_language
772
+ )
773
+ for segment in subs_copy_result["segments"]:
774
+ segment.pop("speaker", None)
775
+
776
+ try:
777
+ writer(
778
+ subs_copy_result,
779
+ name_ori[:-1] + ".mp3",
780
+ word_options,
781
+ )
782
+ except Exception as error:
783
+ logger.error(str(error))
784
+ if str(error) == "list indices must be integers or slices, not str":
785
+ logger.error(
786
+ "Related to poor word segmentation"
787
+ " in segments after alignment."
788
+ )
789
+ subs_copy_result["segments"][0].pop("words")
790
+ writer(
791
+ subs_copy_result,
792
+ name_ori[:-1] + ".mp3",
793
+ word_options,
794
+ )
795
+
796
+ # translated lang
797
+ subs_tra_copy_result = copy.deepcopy(result_diarize)
798
+ subs_tra_copy_result["language"] = (
799
+ "ja" if TRANSLATE_AUDIO_TO in ["ja", "zh", "zh-TW"] else align_language
800
+ )
801
+ subs_tra_copy_result.pop("word_segments", None)
802
+ for segment in subs_tra_copy_result["segments"]:
803
+ for key in ["speaker", "chars", "words"]:
804
+ segment.pop(key, None)
805
+
806
+ writer(
807
+ subs_tra_copy_result,
808
+ name_tra[:-1] + ".mp3",
809
+ word_options,
810
+ )
811
+
812
+ return name_tra + output_format_subtitle
813
+
814
+
815
+ def linguistic_level_segments(
816
+ result_base,
817
+ linguistic_unit="word", # word or char
818
+ ):
819
+ linguistic_unit = linguistic_unit[:4]
820
+ linguistic_unit_key = linguistic_unit + "s"
821
+ result = copy.deepcopy(result_base)
822
+
823
+ if linguistic_unit_key not in result["segments"][0].keys():
824
+ raise ValueError("No alignment detected, can't process")
825
+
826
+ segments_by_unit = []
827
+ for segment in result["segments"]:
828
+ segment_units = segment[linguistic_unit_key]
829
+ # segment_speaker = segment.get("speaker", "SPEAKER_00")
830
+
831
+ for unit in segment_units:
832
+
833
+ text = unit[linguistic_unit]
834
+
835
+ if "start" in unit.keys():
836
+ segments_by_unit.append(
837
+ {
838
+ "start": unit["start"],
839
+ "end": unit["end"],
840
+ "text": text,
841
+ # "speaker": segment_speaker,
842
+ }
843
+ )
844
+ elif not segments_by_unit:
845
+ pass
846
+ else:
847
+ segments_by_unit[-1]["text"] += text
848
+
849
+ return {"segments": segments_by_unit}
850
+
851
+
852
+ def break_aling_segments(
853
+ result: dict,
854
+ break_characters: str = "", # ":|,|.|"
855
+ ):
856
+ result_align = copy.deepcopy(result)
857
+
858
+ break_characters_list = break_characters.split("|")
859
+ break_characters_list = [i for i in break_characters_list if i != '']
860
+
861
+ if not break_characters_list:
862
+ logger.info("No valid break characters were specified.")
863
+ return result
864
+
865
+ logger.info(f"Redivide text segments by: {str(break_characters_list)}")
866
+
867
+ # create new with filters
868
+ normal = []
869
+
870
+ def process_chars(chars, letter_new_start, num, text):
871
+ start_key, end_key = "start", "end"
872
+ start_value = end_value = None
873
+
874
+ for char in chars:
875
+ if start_key in char:
876
+ start_value = char[start_key]
877
+ break
878
+
879
+ for char in reversed(chars):
880
+ if end_key in char:
881
+ end_value = char[end_key]
882
+ break
883
+
884
+ if not start_value or not end_value:
885
+ raise Exception(
886
+ f"Unable to obtain a valid timestamp for chars: {str(chars)}"
887
+ )
888
+
889
+ return {
890
+ "start": start_value,
891
+ "end": end_value,
892
+ "text": text,
893
+ "words": chars,
894
+ }
895
+
896
+ for i, segment in enumerate(result_align['segments']):
897
+
898
+ logger.debug(f"- Process segment: {i}, text: {segment['text']}")
899
+ # start = segment['start']
900
+ letter_new_start = 0
901
+ for num, char in enumerate(segment['chars']):
902
+
903
+ if char["char"] is None:
904
+ continue
905
+
906
+ # if "start" in char:
907
+ # start = char["start"]
908
+
909
+ # if "end" in char:
910
+ # end = char["end"]
911
+
912
+ # Break by character
913
+ if char['char'] in break_characters_list:
914
+
915
+ text = segment['text'][letter_new_start:num+1]
916
+
917
+ logger.debug(
918
+ f"Break in: {char['char']}, position: {num}, text: {text}"
919
+ )
920
+
921
+ chars = segment['chars'][letter_new_start:num+1]
922
+
923
+ if not text:
924
+ logger.debug("No text")
925
+ continue
926
+
927
+ if num == 0 and not text.strip():
928
+ logger.debug("blank space in start")
929
+ continue
930
+
931
+ if len(text) == 1:
932
+ logger.debug(f"Short char append, num: {num}")
933
+ normal[-1]["text"] += text
934
+ normal[-1]["words"].append(chars)
935
+ continue
936
+
937
+ # logger.debug(chars)
938
+ normal_dict = process_chars(chars, letter_new_start, num, text)
939
+
940
+ letter_new_start = num+1
941
+
942
+ normal.append(normal_dict)
943
+
944
+ # If we reach the end of the segment, add the last part of chars.
945
+ if num == len(segment["chars"]) - 1:
946
+
947
+ text = segment['text'][letter_new_start:num+1]
948
+
949
+ # If remain text len is not default len text
950
+ if num not in [len(text)-1, len(text)] and text:
951
+ logger.debug(f'Remaining text: {text}')
952
+
953
+ if not text:
954
+ logger.debug("No remaining text.")
955
+ continue
956
+
957
+ if len(text) == 1:
958
+ logger.debug(f"Short char append, num: {num}")
959
+ normal[-1]["text"] += text
960
+ normal[-1]["words"].append(chars)
961
+ continue
962
+
963
+ chars = segment['chars'][letter_new_start:num+1]
964
+
965
+ normal_dict = process_chars(chars, letter_new_start, num, text)
966
+
967
+ letter_new_start = num+1
968
+
969
+ normal.append(normal_dict)
970
+
971
+ # Rename char to word
972
+ for item in normal:
973
+ words_list = item['words']
974
+ for word_item in words_list:
975
+ if 'char' in word_item:
976
+ word_item['word'] = word_item.pop('char')
977
+
978
+ # Convert to dict default
979
+ break_segments = {"segments": normal}
980
+
981
+ msg_count = (
982
+ f"Segment count before: {len(result['segments'])}, "
983
+ f"after: {len(break_segments['segments'])}."
984
+ )
985
+ logger.info(msg_count)
986
+
987
+ return break_segments
soni_translate/text_to_speech.py ADDED
@@ -0,0 +1,1574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gtts import gTTS
2
+ import edge_tts, asyncio, json, glob # noqa
3
+ from tqdm import tqdm
4
+ import librosa, os, re, torch, gc, subprocess # noqa
5
+ from .language_configuration import (
6
+ fix_code_language,
7
+ BARK_VOICES_LIST,
8
+ VITS_VOICES_LIST,
9
+ )
10
+ from .utils import (
11
+ download_manager,
12
+ create_directories,
13
+ copy_files,
14
+ rename_file,
15
+ remove_directory_contents,
16
+ remove_files,
17
+ run_command,
18
+ )
19
+ import numpy as np
20
+ from typing import Any, Dict
21
+ from pathlib import Path
22
+ import soundfile as sf
23
+ import platform
24
+ import logging
25
+ import traceback
26
+ from .logging_setup import logger
27
+
28
+
29
+ class TTS_OperationError(Exception):
30
+ def __init__(self, message="The operation did not complete successfully."):
31
+ self.message = message
32
+ super().__init__(self.message)
33
+
34
+
35
+ def verify_saved_file_and_size(filename):
36
+ if not os.path.exists(filename):
37
+ raise TTS_OperationError(f"File '{filename}' was not saved.")
38
+ if os.path.getsize(filename) == 0:
39
+ raise TTS_OperationError(
40
+ f"File '{filename}' has a zero size. "
41
+ "Related to incorrect TTS for the target language"
42
+ )
43
+
44
+
45
+ def error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename):
46
+ traceback.print_exc()
47
+ logger.error(f"Error: {str(error)}")
48
+ try:
49
+ from tempfile import TemporaryFile
50
+
51
+ tts = gTTS(segment["text"], lang=fix_code_language(TRANSLATE_AUDIO_TO))
52
+ # tts.save(filename)
53
+ f = TemporaryFile()
54
+ tts.write_to_fp(f)
55
+
56
+ # Reset the file pointer to the beginning of the file
57
+ f.seek(0)
58
+
59
+ # Read audio data from the TemporaryFile using soundfile
60
+ audio_data, samplerate = sf.read(f)
61
+ f.close() # Close the TemporaryFile
62
+ sf.write(
63
+ filename, audio_data, samplerate, format="ogg", subtype="vorbis"
64
+ )
65
+
66
+ logger.warning(
67
+ 'TTS auxiliary will be utilized '
68
+ f'rather than TTS: {segment["tts_name"]}'
69
+ )
70
+ verify_saved_file_and_size(filename)
71
+ except Exception as error:
72
+ logger.critical(f"Error: {str(error)}")
73
+ sample_rate_aux = 22050
74
+ duration = float(segment["end"]) - float(segment["start"])
75
+ data = np.zeros(int(sample_rate_aux * duration)).astype(np.float32)
76
+ sf.write(
77
+ filename, data, sample_rate_aux, format="ogg", subtype="vorbis"
78
+ )
79
+ logger.error("Audio will be replaced -> [silent audio].")
80
+ verify_saved_file_and_size(filename)
81
+
82
+
83
+ def pad_array(array, sr):
84
+
85
+ if isinstance(array, list):
86
+ array = np.array(array)
87
+
88
+ if not array.shape[0]:
89
+ raise ValueError("The generated audio does not contain any data")
90
+
91
+ valid_indices = np.where(np.abs(array) > 0.001)[0]
92
+
93
+ if len(valid_indices) == 0:
94
+ logger.debug(f"No valid indices: {array}")
95
+ return array
96
+
97
+ try:
98
+ pad_indice = int(0.1 * sr)
99
+ start_pad = max(0, valid_indices[0] - pad_indice)
100
+ end_pad = min(len(array), valid_indices[-1] + 1 + pad_indice)
101
+ padded_array = array[start_pad:end_pad]
102
+ return padded_array
103
+ except Exception as error:
104
+ logger.error(str(error))
105
+ return array
106
+
107
+
108
+ # =====================================
109
+ # EDGE TTS
110
+ # =====================================
111
+
112
+
113
+ def edge_tts_voices_list():
114
+ try:
115
+ completed_process = subprocess.run(
116
+ ["edge-tts", "--list-voices"], capture_output=True, text=True
117
+ )
118
+ lines = completed_process.stdout.strip().split("\n")
119
+ except Exception as error:
120
+ logger.debug(str(error))
121
+ lines = []
122
+
123
+ voices = []
124
+ for line in lines:
125
+ if line.startswith("Name: "):
126
+ voice_entry = {}
127
+ voice_entry["Name"] = line.split(": ")[1]
128
+ elif line.startswith("Gender: "):
129
+ voice_entry["Gender"] = line.split(": ")[1]
130
+ voices.append(voice_entry)
131
+
132
+ formatted_voices = [
133
+ f"{entry['Name']}-{entry['Gender']}" for entry in voices
134
+ ]
135
+
136
+ if not formatted_voices:
137
+ logger.warning(
138
+ "The list of Edge TTS voices could not be obtained, "
139
+ "switching to an alternative method"
140
+ )
141
+ tts_voice_list = asyncio.new_event_loop().run_until_complete(
142
+ edge_tts.list_voices()
143
+ )
144
+ formatted_voices = sorted(
145
+ [f"{v['ShortName']}-{v['Gender']}" for v in tts_voice_list]
146
+ )
147
+
148
+ if not formatted_voices:
149
+ logger.error("Can't get EDGE TTS - list voices")
150
+
151
+ return formatted_voices
152
+
153
+
154
+ def segments_egde_tts(filtered_edge_segments, TRANSLATE_AUDIO_TO, is_gui):
155
+ for segment in tqdm(filtered_edge_segments["segments"]):
156
+ speaker = segment["speaker"] # noqa
157
+ text = segment["text"]
158
+ start = segment["start"]
159
+ tts_name = segment["tts_name"]
160
+
161
+ # make the tts audio
162
+ filename = f"audio/{start}.ogg"
163
+ temp_file = filename[:-3] + "mp3"
164
+
165
+ logger.info(f"{text} >> {filename}")
166
+ try:
167
+ if is_gui:
168
+ asyncio.run(
169
+ edge_tts.Communicate(
170
+ text, "-".join(tts_name.split("-")[:-1])
171
+ ).save(temp_file)
172
+ )
173
+ else:
174
+ # nest_asyncio.apply() if not is_gui else None
175
+ command = f'edge-tts -t "{text}" -v "{tts_name.replace("-Male", "").replace("-Female", "")}" --write-media "{temp_file}"'
176
+ run_command(command)
177
+ verify_saved_file_and_size(temp_file)
178
+
179
+ data, sample_rate = sf.read(temp_file)
180
+ data = pad_array(data, sample_rate)
181
+ # os.remove(temp_file)
182
+
183
+ # Save file
184
+ sf.write(
185
+ file=filename,
186
+ samplerate=sample_rate,
187
+ data=data,
188
+ format="ogg",
189
+ subtype="vorbis",
190
+ )
191
+ verify_saved_file_and_size(filename)
192
+
193
+ except Exception as error:
194
+ error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
195
+
196
+
197
+ # =====================================
198
+ # BARK TTS
199
+ # =====================================
200
+
201
+
202
+ def segments_bark_tts(
203
+ filtered_bark_segments, TRANSLATE_AUDIO_TO, model_id_bark="suno/bark-small"
204
+ ):
205
+ from transformers import AutoProcessor, BarkModel
206
+ from optimum.bettertransformer import BetterTransformer
207
+
208
+ device = os.environ.get("SONITR_DEVICE")
209
+ torch_dtype_env = torch.float16 if device == "cuda" else torch.float32
210
+
211
+ # load model bark
212
+ model = BarkModel.from_pretrained(
213
+ model_id_bark, torch_dtype=torch_dtype_env
214
+ ).to(device)
215
+ model = model.to(device)
216
+ processor = AutoProcessor.from_pretrained(
217
+ model_id_bark, return_tensors="pt"
218
+ ) # , padding=True
219
+ if device == "cuda":
220
+ # convert to bettertransformer
221
+ model = BetterTransformer.transform(model, keep_original_model=False)
222
+ # enable CPU offload
223
+ # model.enable_cpu_offload()
224
+ sampling_rate = model.generation_config.sample_rate
225
+
226
+ # filtered_segments = filtered_bark_segments['segments']
227
+ # Sorting the segments by 'tts_name'
228
+ # sorted_segments = sorted(filtered_segments, key=lambda x: x['tts_name'])
229
+ # logger.debug(sorted_segments)
230
+
231
+ for segment in tqdm(filtered_bark_segments["segments"]):
232
+ speaker = segment["speaker"] # noqa
233
+ text = segment["text"]
234
+ start = segment["start"]
235
+ tts_name = segment["tts_name"]
236
+
237
+ inputs = processor(text, voice_preset=BARK_VOICES_LIST[tts_name]).to(
238
+ device
239
+ )
240
+
241
+ # make the tts audio
242
+ filename = f"audio/{start}.ogg"
243
+ logger.info(f"{text} >> {filename}")
244
+ try:
245
+ # Infer
246
+ with torch.inference_mode():
247
+ speech_output = model.generate(
248
+ **inputs,
249
+ do_sample=True,
250
+ fine_temperature=0.4,
251
+ coarse_temperature=0.8,
252
+ pad_token_id=processor.tokenizer.pad_token_id,
253
+ )
254
+ # Save file
255
+ data_tts = pad_array(
256
+ speech_output.cpu().numpy().squeeze().astype(np.float32),
257
+ sampling_rate,
258
+ )
259
+ sf.write(
260
+ file=filename,
261
+ samplerate=sampling_rate,
262
+ data=data_tts,
263
+ format="ogg",
264
+ subtype="vorbis",
265
+ )
266
+ verify_saved_file_and_size(filename)
267
+ except Exception as error:
268
+ error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
269
+ gc.collect()
270
+ torch.cuda.empty_cache()
271
+ try:
272
+ del processor
273
+ del model
274
+ gc.collect()
275
+ torch.cuda.empty_cache()
276
+ except Exception as error:
277
+ logger.error(str(error))
278
+ gc.collect()
279
+ torch.cuda.empty_cache()
280
+
281
+
282
+ # =====================================
283
+ # VITS TTS
284
+ # =====================================
285
+
286
+
287
+ def uromanize(input_string):
288
+ """Convert non-Roman strings to Roman using the `uroman` perl package."""
289
+ # script_path = os.path.join(uroman_path, "bin", "uroman.pl")
290
+
291
+ if not os.path.exists("./uroman"):
292
+ logger.info(
293
+ "Clonning repository uroman https://github.com/isi-nlp/uroman.git"
294
+ " for romanize the text"
295
+ )
296
+ process = subprocess.Popen(
297
+ ["git", "clone", "https://github.com/isi-nlp/uroman.git"],
298
+ stdout=subprocess.PIPE,
299
+ stderr=subprocess.PIPE,
300
+ )
301
+ stdout, stderr = process.communicate()
302
+ script_path = os.path.join("./uroman", "bin", "uroman.pl")
303
+
304
+ command = ["perl", script_path]
305
+
306
+ process = subprocess.Popen(
307
+ command,
308
+ stdin=subprocess.PIPE,
309
+ stdout=subprocess.PIPE,
310
+ stderr=subprocess.PIPE,
311
+ )
312
+ # Execute the perl command
313
+ stdout, stderr = process.communicate(input=input_string.encode())
314
+
315
+ if process.returncode != 0:
316
+ raise ValueError(f"Error {process.returncode}: {stderr.decode()}")
317
+
318
+ # Return the output as a string and skip the new-line character at the end
319
+ return stdout.decode()[:-1]
320
+
321
+
322
+ def segments_vits_tts(filtered_vits_segments, TRANSLATE_AUDIO_TO):
323
+ from transformers import VitsModel, AutoTokenizer
324
+
325
+ filtered_segments = filtered_vits_segments["segments"]
326
+ # Sorting the segments by 'tts_name'
327
+ sorted_segments = sorted(filtered_segments, key=lambda x: x["tts_name"])
328
+ logger.debug(sorted_segments)
329
+
330
+ model_name_key = None
331
+ for segment in tqdm(sorted_segments):
332
+ speaker = segment["speaker"] # noqa
333
+ text = segment["text"]
334
+ start = segment["start"]
335
+ tts_name = segment["tts_name"]
336
+
337
+ if tts_name != model_name_key:
338
+ model_name_key = tts_name
339
+ model = VitsModel.from_pretrained(VITS_VOICES_LIST[tts_name])
340
+ tokenizer = AutoTokenizer.from_pretrained(
341
+ VITS_VOICES_LIST[tts_name]
342
+ )
343
+ sampling_rate = model.config.sampling_rate
344
+
345
+ if tokenizer.is_uroman:
346
+ romanize_text = uromanize(text)
347
+ logger.debug(f"Romanize text: {romanize_text}")
348
+ inputs = tokenizer(romanize_text, return_tensors="pt")
349
+ else:
350
+ inputs = tokenizer(text, return_tensors="pt")
351
+
352
+ # make the tts audio
353
+ filename = f"audio/{start}.ogg"
354
+ logger.info(f"{text} >> {filename}")
355
+ try:
356
+ # Infer
357
+ with torch.no_grad():
358
+ speech_output = model(**inputs).waveform
359
+
360
+ data_tts = pad_array(
361
+ speech_output.cpu().numpy().squeeze().astype(np.float32),
362
+ sampling_rate,
363
+ )
364
+ # Save file
365
+ sf.write(
366
+ file=filename,
367
+ samplerate=sampling_rate,
368
+ data=data_tts,
369
+ format="ogg",
370
+ subtype="vorbis",
371
+ )
372
+ verify_saved_file_and_size(filename)
373
+ except Exception as error:
374
+ error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
375
+ gc.collect()
376
+ torch.cuda.empty_cache()
377
+ try:
378
+ del tokenizer
379
+ del model
380
+ gc.collect()
381
+ torch.cuda.empty_cache()
382
+ except Exception as error:
383
+ logger.error(str(error))
384
+ gc.collect()
385
+ torch.cuda.empty_cache()
386
+
387
+
388
+ # =====================================
389
+ # Coqui XTTS
390
+ # =====================================
391
+
392
+
393
+ def coqui_xtts_voices_list():
394
+ main_folder = "_XTTS_"
395
+ pattern_coqui = re.compile(r".+\.(wav|mp3|ogg|m4a)$")
396
+ pattern_automatic_speaker = re.compile(r"AUTOMATIC_SPEAKER_\d+\.wav$")
397
+
398
+ # List only files in the directory matching the pattern but not matching
399
+ # AUTOMATIC_SPEAKER_00.wav, AUTOMATIC_SPEAKER_01.wav, etc.
400
+ wav_voices = [
401
+ "_XTTS_/" + f
402
+ for f in os.listdir(main_folder)
403
+ if os.path.isfile(os.path.join(main_folder, f))
404
+ and pattern_coqui.match(f)
405
+ and not pattern_automatic_speaker.match(f)
406
+ ]
407
+
408
+ return ["_XTTS_/AUTOMATIC.wav"] + wav_voices
409
+
410
+
411
+ def seconds_to_hhmmss_ms(seconds):
412
+ hours = seconds // 3600
413
+ minutes = (seconds % 3600) // 60
414
+ seconds = seconds % 60
415
+ milliseconds = int((seconds - int(seconds)) * 1000)
416
+ return "%02d:%02d:%02d.%03d" % (hours, minutes, int(seconds), milliseconds)
417
+
418
+
419
+ def audio_trimming(audio_path, destination, start, end):
420
+ if isinstance(start, (int, float)):
421
+ start = seconds_to_hhmmss_ms(start)
422
+ if isinstance(end, (int, float)):
423
+ end = seconds_to_hhmmss_ms(end)
424
+
425
+ if destination:
426
+ file_directory = destination
427
+ else:
428
+ file_directory = os.path.dirname(audio_path)
429
+
430
+ file_name = os.path.splitext(os.path.basename(audio_path))[0]
431
+ file_ = f"{file_name}_trim.wav"
432
+ # file_ = f'{os.path.splitext(audio_path)[0]}_trim.wav'
433
+ output_path = os.path.join(file_directory, file_)
434
+
435
+ # -t (duration from -ss) | -to (time stop) | -af silenceremove=1:0:-50dB (remove silence)
436
+ command = f'ffmpeg -y -loglevel error -i "{audio_path}" -ss {start} -to {end} -acodec pcm_s16le -f wav "{output_path}"'
437
+ run_command(command)
438
+
439
+ return output_path
440
+
441
+
442
+ def convert_to_xtts_good_sample(audio_path: str = "", destination: str = ""):
443
+ if destination:
444
+ file_directory = destination
445
+ else:
446
+ file_directory = os.path.dirname(audio_path)
447
+
448
+ file_name = os.path.splitext(os.path.basename(audio_path))[0]
449
+ file_ = f"{file_name}_good_sample.wav"
450
+ # file_ = f'{os.path.splitext(audio_path)[0]}_good_sample.wav'
451
+ mono_path = os.path.join(file_directory, file_) # get root
452
+
453
+ command = f'ffmpeg -y -loglevel error -i "{audio_path}" -ac 1 -ar 22050 -sample_fmt s16 -f wav "{mono_path}"'
454
+ run_command(command)
455
+
456
+ return mono_path
457
+
458
+
459
+ def sanitize_file_name(file_name):
460
+ import unicodedata
461
+
462
+ # Normalize the string to NFKD form to separate combined characters into
463
+ # base characters and diacritics
464
+ normalized_name = unicodedata.normalize("NFKD", file_name)
465
+ # Replace any non-ASCII characters or special symbols with an underscore
466
+ sanitized_name = re.sub(r"[^\w\s.-]", "_", normalized_name)
467
+ return sanitized_name
468
+
469
+
470
+ def create_wav_file_vc(
471
+ sample_name="", # name final file
472
+ audio_wav="", # path
473
+ start=None, # trim start
474
+ end=None, # trim end
475
+ output_final_path="_XTTS_",
476
+ get_vocals_dereverb=True,
477
+ ):
478
+ sample_name = sample_name if sample_name else "default_name"
479
+ sample_name = sanitize_file_name(sample_name)
480
+ audio_wav = audio_wav if isinstance(audio_wav, str) else audio_wav.name
481
+
482
+ BASE_DIR = (
483
+ "." # os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
484
+ )
485
+
486
+ output_dir = os.path.join(BASE_DIR, "clean_song_output") # remove content
487
+ # remove_directory_contents(output_dir)
488
+
489
+ if start or end:
490
+ # Cut file
491
+ audio_segment = audio_trimming(audio_wav, output_dir, start, end)
492
+ else:
493
+ # Complete file
494
+ audio_segment = audio_wav
495
+
496
+ from .mdx_net import process_uvr_task
497
+
498
+ try:
499
+ _, _, _, _, audio_segment = process_uvr_task(
500
+ orig_song_path=audio_segment,
501
+ main_vocals=True,
502
+ dereverb=get_vocals_dereverb,
503
+ )
504
+ except Exception as error:
505
+ logger.error(str(error))
506
+
507
+ sample = convert_to_xtts_good_sample(audio_segment)
508
+
509
+ sample_name = f"{sample_name}.wav"
510
+ sample_rename = rename_file(sample, sample_name)
511
+
512
+ copy_files(sample_rename, output_final_path)
513
+
514
+ final_sample = os.path.join(output_final_path, sample_name)
515
+ if os.path.exists(final_sample):
516
+ logger.info(final_sample)
517
+ return final_sample
518
+ else:
519
+ raise Exception(f"Error wav: {final_sample}")
520
+
521
+
522
+ def create_new_files_for_vc(
523
+ speakers_coqui,
524
+ segments_base,
525
+ dereverb_automatic=True
526
+ ):
527
+ # before function delete automatic delete_previous_automatic
528
+ output_dir = os.path.join(".", "clean_song_output") # remove content
529
+ remove_directory_contents(output_dir)
530
+
531
+ for speaker in speakers_coqui:
532
+ filtered_speaker = [
533
+ segment
534
+ for segment in segments_base
535
+ if segment["speaker"] == speaker
536
+ ]
537
+ if len(filtered_speaker) > 4:
538
+ filtered_speaker = filtered_speaker[1:]
539
+ if filtered_speaker[0]["tts_name"] == "_XTTS_/AUTOMATIC.wav":
540
+ name_automatic_wav = f"AUTOMATIC_{speaker}"
541
+ if os.path.exists(f"_XTTS_/{name_automatic_wav}.wav"):
542
+ logger.info(f"WAV automatic {speaker} exists")
543
+ # path_wav = path_automatic_wav
544
+ pass
545
+ else:
546
+ # create wav
547
+ wav_ok = False
548
+ for seg in filtered_speaker:
549
+ duration = float(seg["end"]) - float(seg["start"])
550
+ if duration > 7.0 and duration < 12.0:
551
+ logger.info(
552
+ f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {duration}, {seg["text"]}'
553
+ )
554
+ create_wav_file_vc(
555
+ sample_name=name_automatic_wav,
556
+ audio_wav="audio.wav",
557
+ start=(float(seg["start"]) + 1.0),
558
+ end=(float(seg["end"]) - 1.0),
559
+ get_vocals_dereverb=dereverb_automatic,
560
+ )
561
+ wav_ok = True
562
+ break
563
+
564
+ if not wav_ok:
565
+ logger.info("Taking the first segment")
566
+ seg = filtered_speaker[0]
567
+ logger.info(
568
+ f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {seg["text"]}'
569
+ )
570
+ max_duration = float(seg["end"]) - float(seg["start"])
571
+ max_duration = max(2.0, min(max_duration, 9.0))
572
+
573
+ create_wav_file_vc(
574
+ sample_name=name_automatic_wav,
575
+ audio_wav="audio.wav",
576
+ start=(float(seg["start"])),
577
+ end=(float(seg["start"]) + max_duration),
578
+ get_vocals_dereverb=dereverb_automatic,
579
+ )
580
+
581
+
582
+ def segments_coqui_tts(
583
+ filtered_coqui_segments,
584
+ TRANSLATE_AUDIO_TO,
585
+ model_id_coqui="tts_models/multilingual/multi-dataset/xtts_v2",
586
+ speakers_coqui=None,
587
+ delete_previous_automatic=True,
588
+ dereverb_automatic=True,
589
+ emotion=None,
590
+ ):
591
+ """XTTS
592
+ Install:
593
+ pip install -q TTS==0.21.1
594
+ pip install -q numpy==1.23.5
595
+
596
+ Notes:
597
+ - tts_name is the wav|mp3|ogg|m4a file for VC
598
+ """
599
+ from TTS.api import TTS
600
+
601
+ TRANSLATE_AUDIO_TO = fix_code_language(TRANSLATE_AUDIO_TO, syntax="coqui")
602
+ supported_lang_coqui = [
603
+ "zh-cn",
604
+ "en",
605
+ "fr",
606
+ "de",
607
+ "it",
608
+ "pt",
609
+ "pl",
610
+ "tr",
611
+ "ru",
612
+ "nl",
613
+ "cs",
614
+ "ar",
615
+ "es",
616
+ "hu",
617
+ "ko",
618
+ "ja",
619
+ ]
620
+ if TRANSLATE_AUDIO_TO not in supported_lang_coqui:
621
+ raise TTS_OperationError(
622
+ f"'{TRANSLATE_AUDIO_TO}' is not a supported language for Coqui XTTS"
623
+ )
624
+ # Emotion and speed can only be used with Coqui Studio models. discontinued
625
+ # emotions = ["Neutral", "Happy", "Sad", "Angry", "Dull"]
626
+
627
+ if delete_previous_automatic:
628
+ for spk in speakers_coqui:
629
+ remove_files(f"_XTTS_/AUTOMATIC_{spk}.wav")
630
+
631
+ directory_audios_vc = "_XTTS_"
632
+ create_directories(directory_audios_vc)
633
+ create_new_files_for_vc(
634
+ speakers_coqui,
635
+ filtered_coqui_segments["segments"],
636
+ dereverb_automatic,
637
+ )
638
+
639
+ # Init TTS
640
+ device = os.environ.get("SONITR_DEVICE")
641
+ model = TTS(model_id_coqui).to(device)
642
+ sampling_rate = 24000
643
+
644
+ # filtered_segments = filtered_coqui_segments['segments']
645
+ # Sorting the segments by 'tts_name'
646
+ # sorted_segments = sorted(filtered_segments, key=lambda x: x['tts_name'])
647
+ # logger.debug(sorted_segments)
648
+
649
+ for segment in tqdm(filtered_coqui_segments["segments"]):
650
+ speaker = segment["speaker"]
651
+ text = segment["text"]
652
+ start = segment["start"]
653
+ tts_name = segment["tts_name"]
654
+ if tts_name == "_XTTS_/AUTOMATIC.wav":
655
+ tts_name = f"_XTTS_/AUTOMATIC_{speaker}.wav"
656
+
657
+ # make the tts audio
658
+ filename = f"audio/{start}.ogg"
659
+ logger.info(f"{text} >> {filename}")
660
+ try:
661
+ # Infer
662
+ wav = model.tts(
663
+ text=text, speaker_wav=tts_name, language=TRANSLATE_AUDIO_TO
664
+ )
665
+ data_tts = pad_array(
666
+ wav,
667
+ sampling_rate,
668
+ )
669
+ # Save file
670
+ sf.write(
671
+ file=filename,
672
+ samplerate=sampling_rate,
673
+ data=data_tts,
674
+ format="ogg",
675
+ subtype="vorbis",
676
+ )
677
+ verify_saved_file_and_size(filename)
678
+ except Exception as error:
679
+ error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
680
+ gc.collect()
681
+ torch.cuda.empty_cache()
682
+ try:
683
+ del model
684
+ gc.collect()
685
+ torch.cuda.empty_cache()
686
+ except Exception as error:
687
+ logger.error(str(error))
688
+ gc.collect()
689
+ torch.cuda.empty_cache()
690
+
691
+
692
+ # =====================================
693
+ # PIPER TTS
694
+ # =====================================
695
+
696
+
697
+ def piper_tts_voices_list():
698
+ file_path = download_manager(
699
+ url="https://huggingface.co/rhasspy/piper-voices/resolve/main/voices.json",
700
+ path="./PIPER_MODELS",
701
+ )
702
+
703
+ with open(file_path, "r", encoding="utf8") as file:
704
+ data = json.load(file)
705
+ piper_id_models = [key + " VITS-onnx" for key in data.keys()]
706
+
707
+ return piper_id_models
708
+
709
+
710
+ def replace_text_in_json(file_path, key_to_replace, new_text, condition=None):
711
+ # Read the JSON file
712
+ with open(file_path, "r", encoding="utf-8") as file:
713
+ data = json.load(file)
714
+
715
+ # Modify the specified key's value with the new text
716
+ if key_to_replace in data:
717
+ if condition:
718
+ value_condition = condition
719
+ else:
720
+ value_condition = data[key_to_replace]
721
+
722
+ if data[key_to_replace] == value_condition:
723
+ data[key_to_replace] = new_text
724
+
725
+ # Write the modified content back to the JSON file
726
+ with open(file_path, "w") as file:
727
+ json.dump(
728
+ data, file, indent=2
729
+ ) # Write the modified data back to the file with indentation for readability
730
+
731
+
732
+ def load_piper_model(
733
+ model: str,
734
+ data_dir: list,
735
+ download_dir: str = "",
736
+ update_voices: bool = False,
737
+ ):
738
+ from piper import PiperVoice
739
+ from piper.download import ensure_voice_exists, find_voice, get_voices
740
+
741
+ try:
742
+ import onnxruntime as rt
743
+
744
+ if rt.get_device() == "GPU" and os.environ.get("SONITR_DEVICE") == "cuda":
745
+ logger.debug("onnxruntime device > GPU")
746
+ cuda = True
747
+ else:
748
+ logger.info(
749
+ "onnxruntime device > CPU"
750
+ ) # try pip install onnxruntime-gpu
751
+ cuda = False
752
+ except Exception as error:
753
+ raise TTS_OperationError(f"onnxruntime error: {str(error)}")
754
+
755
+ # Disable CUDA in Windows
756
+ if platform.system() == "Windows":
757
+ logger.info("Employing CPU exclusivity with Piper TTS")
758
+ cuda = False
759
+
760
+ if not download_dir:
761
+ # Download to first data directory by default
762
+ download_dir = data_dir[0]
763
+ else:
764
+ data_dir = [os.path.join(data_dir[0], download_dir)]
765
+
766
+ # Download voice if file doesn't exist
767
+ model_path = Path(model)
768
+ if not model_path.exists():
769
+ # Load voice info
770
+ voices_info = get_voices(download_dir, update_voices=update_voices)
771
+
772
+ # Resolve aliases for backwards compatibility with old voice names
773
+ aliases_info: Dict[str, Any] = {}
774
+ for voice_info in voices_info.values():
775
+ for voice_alias in voice_info.get("aliases", []):
776
+ aliases_info[voice_alias] = {"_is_alias": True, **voice_info}
777
+
778
+ voices_info.update(aliases_info)
779
+ ensure_voice_exists(model, data_dir, download_dir, voices_info)
780
+ model, config = find_voice(model, data_dir)
781
+
782
+ replace_text_in_json(
783
+ config, "phoneme_type", "espeak", "PhonemeType.ESPEAK"
784
+ )
785
+
786
+ # Load voice
787
+ voice = PiperVoice.load(model, config_path=config, use_cuda=cuda)
788
+
789
+ return voice
790
+
791
+
792
+ def synthesize_text_to_audio_np_array(voice, text, synthesize_args):
793
+ audio_stream = voice.synthesize_stream_raw(text, **synthesize_args)
794
+
795
+ # Collect the audio bytes into a single NumPy array
796
+ audio_data = b""
797
+ for audio_bytes in audio_stream:
798
+ audio_data += audio_bytes
799
+
800
+ # Ensure correct data type and convert audio bytes to NumPy array
801
+ audio_np = np.frombuffer(audio_data, dtype=np.int16)
802
+ return audio_np
803
+
804
+
805
+ def segments_vits_onnx_tts(filtered_onnx_vits_segments, TRANSLATE_AUDIO_TO):
806
+ """
807
+ Install:
808
+ pip install -q piper-tts==1.2.0 onnxruntime-gpu # for cuda118
809
+ """
810
+
811
+ data_dir = [
812
+ str(Path.cwd())
813
+ ] # "Data directory to check for downloaded models (default: current directory)"
814
+ download_dir = "PIPER_MODELS"
815
+ # model_name = "en_US-lessac-medium" tts_name in a dict like VITS
816
+ update_voices = True # "Download latest voices.json during startup",
817
+
818
+ synthesize_args = {
819
+ "speaker_id": None,
820
+ "length_scale": 1.0,
821
+ "noise_scale": 0.667,
822
+ "noise_w": 0.8,
823
+ "sentence_silence": 0.0,
824
+ }
825
+
826
+ filtered_segments = filtered_onnx_vits_segments["segments"]
827
+ # Sorting the segments by 'tts_name'
828
+ sorted_segments = sorted(filtered_segments, key=lambda x: x["tts_name"])
829
+ logger.debug(sorted_segments)
830
+
831
+ model_name_key = None
832
+ for segment in tqdm(sorted_segments):
833
+ speaker = segment["speaker"] # noqa
834
+ text = segment["text"]
835
+ start = segment["start"]
836
+ tts_name = segment["tts_name"].replace(" VITS-onnx", "")
837
+
838
+ if tts_name != model_name_key:
839
+ model_name_key = tts_name
840
+ model = load_piper_model(
841
+ tts_name, data_dir, download_dir, update_voices
842
+ )
843
+ sampling_rate = model.config.sample_rate
844
+
845
+ # make the tts audio
846
+ filename = f"audio/{start}.ogg"
847
+ logger.info(f"{text} >> {filename}")
848
+ try:
849
+ # Infer
850
+ speech_output = synthesize_text_to_audio_np_array(
851
+ model, text, synthesize_args
852
+ )
853
+ data_tts = pad_array(
854
+ speech_output, # .cpu().numpy().squeeze().astype(np.float32),
855
+ sampling_rate,
856
+ )
857
+ # Save file
858
+ sf.write(
859
+ file=filename,
860
+ samplerate=sampling_rate,
861
+ data=data_tts,
862
+ format="ogg",
863
+ subtype="vorbis",
864
+ )
865
+ verify_saved_file_and_size(filename)
866
+ except Exception as error:
867
+ error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
868
+ gc.collect()
869
+ torch.cuda.empty_cache()
870
+ try:
871
+ del model
872
+ gc.collect()
873
+ torch.cuda.empty_cache()
874
+ except Exception as error:
875
+ logger.error(str(error))
876
+ gc.collect()
877
+ torch.cuda.empty_cache()
878
+
879
+
880
+ # =====================================
881
+ # CLOSEAI TTS
882
+ # =====================================
883
+
884
+
885
+ def segments_openai_tts(
886
+ filtered_openai_tts_segments, TRANSLATE_AUDIO_TO
887
+ ):
888
+ from openai import OpenAI
889
+
890
+ client = OpenAI()
891
+ sampling_rate = 24000
892
+
893
+ # filtered_segments = filtered_openai_tts_segments['segments']
894
+ # Sorting the segments by 'tts_name'
895
+ # sorted_segments = sorted(filtered_segments, key=lambda x: x['tts_name'])
896
+
897
+ for segment in tqdm(filtered_openai_tts_segments["segments"]):
898
+ speaker = segment["speaker"] # noqa
899
+ text = segment["text"].strip()
900
+ start = segment["start"]
901
+ tts_name = segment["tts_name"]
902
+
903
+ # make the tts audio
904
+ filename = f"audio/{start}.ogg"
905
+ logger.info(f"{text} >> {filename}")
906
+
907
+ try:
908
+ # Request
909
+ response = client.audio.speech.create(
910
+ model="tts-1-hd" if "HD" in tts_name else "tts-1",
911
+ voice=tts_name.split()[0][1:],
912
+ response_format="wav",
913
+ input=text
914
+ )
915
+
916
+ audio_bytes = b''
917
+ for data in response.iter_bytes(chunk_size=4096):
918
+ audio_bytes += data
919
+
920
+ speech_output = np.frombuffer(audio_bytes, dtype=np.int16)
921
+
922
+ # Save file
923
+ data_tts = pad_array(
924
+ speech_output[240:],
925
+ sampling_rate,
926
+ )
927
+
928
+ sf.write(
929
+ file=filename,
930
+ samplerate=sampling_rate,
931
+ data=data_tts,
932
+ format="ogg",
933
+ subtype="vorbis",
934
+ )
935
+ verify_saved_file_and_size(filename)
936
+
937
+ except Exception as error:
938
+ error_handling_in_tts(error, segment, TRANSLATE_AUDIO_TO, filename)
939
+
940
+
941
+ # =====================================
942
+ # Select task TTS
943
+ # =====================================
944
+
945
+
946
+ def find_spkr(pattern, speaker_to_voice, segments):
947
+ return [
948
+ speaker
949
+ for speaker, voice in speaker_to_voice.items()
950
+ if pattern.match(voice) and any(
951
+ segment["speaker"] == speaker for segment in segments
952
+ )
953
+ ]
954
+
955
+
956
+ def filter_by_speaker(speakers, segments):
957
+ return {
958
+ "segments": [
959
+ segment
960
+ for segment in segments
961
+ if segment["speaker"] in speakers
962
+ ]
963
+ }
964
+
965
+
966
+ def audio_segmentation_to_voice(
967
+ result_diarize,
968
+ TRANSLATE_AUDIO_TO,
969
+ is_gui,
970
+ tts_voice00,
971
+ tts_voice01="",
972
+ tts_voice02="",
973
+ tts_voice03="",
974
+ tts_voice04="",
975
+ tts_voice05="",
976
+ tts_voice06="",
977
+ tts_voice07="",
978
+ tts_voice08="",
979
+ tts_voice09="",
980
+ tts_voice10="",
981
+ tts_voice11="",
982
+ dereverb_automatic=True,
983
+ model_id_bark="suno/bark-small",
984
+ model_id_coqui="tts_models/multilingual/multi-dataset/xtts_v2",
985
+ delete_previous_automatic=True,
986
+ ):
987
+
988
+ remove_directory_contents("audio")
989
+
990
+ # Mapping speakers to voice variables
991
+ speaker_to_voice = {
992
+ "SPEAKER_00": tts_voice00,
993
+ "SPEAKER_01": tts_voice01,
994
+ "SPEAKER_02": tts_voice02,
995
+ "SPEAKER_03": tts_voice03,
996
+ "SPEAKER_04": tts_voice04,
997
+ "SPEAKER_05": tts_voice05,
998
+ "SPEAKER_06": tts_voice06,
999
+ "SPEAKER_07": tts_voice07,
1000
+ "SPEAKER_08": tts_voice08,
1001
+ "SPEAKER_09": tts_voice09,
1002
+ "SPEAKER_10": tts_voice10,
1003
+ "SPEAKER_11": tts_voice11,
1004
+ }
1005
+
1006
+ # Assign 'SPEAKER_00' to segments without a 'speaker' key
1007
+ for segment in result_diarize["segments"]:
1008
+ if "speaker" not in segment:
1009
+ segment["speaker"] = "SPEAKER_00"
1010
+ logger.warning(
1011
+ "NO SPEAKER DETECT IN SEGMENT: First TTS will be used in the"
1012
+ f" segment time {segment['start'], segment['text']}"
1013
+ )
1014
+ # Assign the TTS name
1015
+ segment["tts_name"] = speaker_to_voice[segment["speaker"]]
1016
+
1017
+ # Find TTS method
1018
+ pattern_edge = re.compile(r".*-(Male|Female)$")
1019
+ pattern_bark = re.compile(r".* BARK$")
1020
+ pattern_vits = re.compile(r".* VITS$")
1021
+ pattern_coqui = re.compile(r".+\.(wav|mp3|ogg|m4a)$")
1022
+ pattern_vits_onnx = re.compile(r".* VITS-onnx$")
1023
+ pattern_openai_tts = re.compile(r".* OpenAI-TTS$")
1024
+
1025
+ all_segments = result_diarize["segments"]
1026
+
1027
+ speakers_edge = find_spkr(pattern_edge, speaker_to_voice, all_segments)
1028
+ speakers_bark = find_spkr(pattern_bark, speaker_to_voice, all_segments)
1029
+ speakers_vits = find_spkr(pattern_vits, speaker_to_voice, all_segments)
1030
+ speakers_coqui = find_spkr(pattern_coqui, speaker_to_voice, all_segments)
1031
+ speakers_vits_onnx = find_spkr(
1032
+ pattern_vits_onnx, speaker_to_voice, all_segments
1033
+ )
1034
+ speakers_openai_tts = find_spkr(
1035
+ pattern_openai_tts, speaker_to_voice, all_segments
1036
+ )
1037
+
1038
+ # Filter method in segments
1039
+ filtered_edge = filter_by_speaker(speakers_edge, all_segments)
1040
+ filtered_bark = filter_by_speaker(speakers_bark, all_segments)
1041
+ filtered_vits = filter_by_speaker(speakers_vits, all_segments)
1042
+ filtered_coqui = filter_by_speaker(speakers_coqui, all_segments)
1043
+ filtered_vits_onnx = filter_by_speaker(speakers_vits_onnx, all_segments)
1044
+ filtered_openai_tts = filter_by_speaker(speakers_openai_tts, all_segments)
1045
+
1046
+ # Infer
1047
+ if filtered_edge["segments"]:
1048
+ logger.info(f"EDGE TTS: {speakers_edge}")
1049
+ segments_egde_tts(filtered_edge, TRANSLATE_AUDIO_TO, is_gui) # mp3
1050
+ if filtered_bark["segments"]:
1051
+ logger.info(f"BARK TTS: {speakers_bark}")
1052
+ segments_bark_tts(
1053
+ filtered_bark, TRANSLATE_AUDIO_TO, model_id_bark
1054
+ ) # wav
1055
+ if filtered_vits["segments"]:
1056
+ logger.info(f"VITS TTS: {speakers_vits}")
1057
+ segments_vits_tts(filtered_vits, TRANSLATE_AUDIO_TO) # wav
1058
+ if filtered_coqui["segments"]:
1059
+ logger.info(f"Coqui TTS: {speakers_coqui}")
1060
+ segments_coqui_tts(
1061
+ filtered_coqui,
1062
+ TRANSLATE_AUDIO_TO,
1063
+ model_id_coqui,
1064
+ speakers_coqui,
1065
+ delete_previous_automatic,
1066
+ dereverb_automatic,
1067
+ ) # wav
1068
+ if filtered_vits_onnx["segments"]:
1069
+ logger.info(f"PIPER TTS: {speakers_vits_onnx}")
1070
+ segments_vits_onnx_tts(filtered_vits_onnx, TRANSLATE_AUDIO_TO) # wav
1071
+ if filtered_openai_tts["segments"]:
1072
+ logger.info(f"OpenAI TTS: {speakers_openai_tts}")
1073
+ segments_openai_tts(filtered_openai_tts, TRANSLATE_AUDIO_TO) # wav
1074
+
1075
+ [result.pop("tts_name", None) for result in result_diarize["segments"]]
1076
+ return [
1077
+ speakers_edge,
1078
+ speakers_bark,
1079
+ speakers_vits,
1080
+ speakers_coqui,
1081
+ speakers_vits_onnx,
1082
+ speakers_openai_tts
1083
+ ]
1084
+
1085
+
1086
+ def accelerate_segments(
1087
+ result_diarize,
1088
+ max_accelerate_audio,
1089
+ valid_speakers,
1090
+ acceleration_rate_regulation=False,
1091
+ folder_output="audio2",
1092
+ ):
1093
+ logger.info("Apply acceleration")
1094
+
1095
+ (
1096
+ speakers_edge,
1097
+ speakers_bark,
1098
+ speakers_vits,
1099
+ speakers_coqui,
1100
+ speakers_vits_onnx,
1101
+ speakers_openai_tts
1102
+ ) = valid_speakers
1103
+
1104
+ create_directories(f"{folder_output}/audio/")
1105
+ remove_directory_contents(f"{folder_output}/audio/")
1106
+
1107
+ audio_files = []
1108
+ speakers_list = []
1109
+
1110
+ max_count_segments_idx = len(result_diarize["segments"]) - 1
1111
+
1112
+ for i, segment in tqdm(enumerate(result_diarize["segments"])):
1113
+ text = segment["text"] # noqa
1114
+ start = segment["start"]
1115
+ end = segment["end"]
1116
+ speaker = segment["speaker"]
1117
+
1118
+ # find name audio
1119
+ # if speaker in speakers_edge:
1120
+ filename = f"audio/{start}.ogg"
1121
+ # elif speaker in speakers_bark + speakers_vits + speakers_coqui + speakers_vits_onnx:
1122
+ # filename = f"audio/{start}.wav" # wav
1123
+
1124
+ # duration
1125
+ duration_true = end - start
1126
+ duration_tts = librosa.get_duration(filename=filename)
1127
+
1128
+ # Accelerate percentage
1129
+ acc_percentage = duration_tts / duration_true
1130
+
1131
+ # Smoth
1132
+ if acceleration_rate_regulation and acc_percentage >= 1.3:
1133
+ try:
1134
+ next_segment = result_diarize["segments"][
1135
+ min(max_count_segments_idx, i + 1)
1136
+ ]
1137
+ next_start = next_segment["start"]
1138
+ next_speaker = next_segment["speaker"]
1139
+ duration_with_next_start = next_start - start
1140
+
1141
+ if duration_with_next_start > duration_true:
1142
+ extra_time = duration_with_next_start - duration_true
1143
+
1144
+ if speaker == next_speaker:
1145
+ # half
1146
+ smoth_duration = duration_true + (extra_time * 0.5)
1147
+ else:
1148
+ # 7/10
1149
+ smoth_duration = duration_true + (extra_time * 0.7)
1150
+ logger.debug(
1151
+ f"Base acc: {acc_percentage}, "
1152
+ f"smoth acc: {duration_tts / smoth_duration}"
1153
+ )
1154
+ acc_percentage = max(1.2, (duration_tts / smoth_duration))
1155
+
1156
+ except Exception as error:
1157
+ logger.error(str(error))
1158
+
1159
+ if acc_percentage > max_accelerate_audio:
1160
+ acc_percentage = max_accelerate_audio
1161
+ elif acc_percentage <= 1.15 and acc_percentage >= 0.8:
1162
+ acc_percentage = 1.0
1163
+ elif acc_percentage <= 0.79:
1164
+ acc_percentage = 0.8
1165
+
1166
+ # Round
1167
+ acc_percentage = round(acc_percentage + 0.0, 1)
1168
+
1169
+ # Format read if need
1170
+ if speaker in speakers_edge:
1171
+ info_enc = sf.info(filename).format
1172
+ else:
1173
+ info_enc = "OGG"
1174
+
1175
+ # Apply aceleration or opposite to the audio file in folder_output folder
1176
+ if acc_percentage == 1.0 and info_enc == "OGG":
1177
+ copy_files(filename, f"{folder_output}{os.sep}audio")
1178
+ else:
1179
+ os.system(
1180
+ f"ffmpeg -y -loglevel panic -i {filename} -filter:a atempo={acc_percentage} {folder_output}/{filename}"
1181
+ )
1182
+
1183
+ if logger.isEnabledFor(logging.DEBUG):
1184
+ duration_create = librosa.get_duration(
1185
+ filename=f"{folder_output}/{filename}"
1186
+ )
1187
+ logger.debug(
1188
+ f"acc_percen is {acc_percentage}, tts duration "
1189
+ f"is {duration_tts}, new duration is {duration_create}"
1190
+ f", for {filename}"
1191
+ )
1192
+
1193
+ audio_files.append(f"{folder_output}/{filename}")
1194
+ speaker = "TTS Speaker {:02d}".format(int(speaker[-2:]) + 1)
1195
+ speakers_list.append(speaker)
1196
+
1197
+ return audio_files, speakers_list
1198
+
1199
+
1200
+ # =====================================
1201
+ # Tone color converter
1202
+ # =====================================
1203
+
1204
+
1205
+ def se_process_audio_segments(
1206
+ source_seg, tone_color_converter, device, remove_previous_processed=True
1207
+ ):
1208
+ # list wav seg
1209
+ source_audio_segs = glob.glob(f"{source_seg}/*.wav")
1210
+ if not source_audio_segs:
1211
+ raise ValueError(
1212
+ f"No audio segments found in {str(source_audio_segs)}"
1213
+ )
1214
+
1215
+ source_se_path = os.path.join(source_seg, "se.pth")
1216
+
1217
+ # if exist not create wav
1218
+ if os.path.isfile(source_se_path):
1219
+ se = torch.load(source_se_path).to(device)
1220
+ logger.debug(f"Previous created {source_se_path}")
1221
+ else:
1222
+ se = tone_color_converter.extract_se(source_audio_segs, source_se_path)
1223
+
1224
+ return se
1225
+
1226
+
1227
+ def create_wav_vc(
1228
+ valid_speakers,
1229
+ segments_base,
1230
+ audio_name,
1231
+ max_segments=10,
1232
+ target_dir="processed",
1233
+ get_vocals_dereverb=False,
1234
+ ):
1235
+ # valid_speakers = list({item['speaker'] for item in segments_base})
1236
+
1237
+ # Before function delete automatic delete_previous_automatic
1238
+ output_dir = os.path.join(".", target_dir) # remove content
1239
+ # remove_directory_contents(output_dir)
1240
+
1241
+ path_source_segments = []
1242
+ path_target_segments = []
1243
+ for speaker in valid_speakers:
1244
+ filtered_speaker = [
1245
+ segment
1246
+ for segment in segments_base
1247
+ if segment["speaker"] == speaker
1248
+ ]
1249
+ if len(filtered_speaker) > 4:
1250
+ filtered_speaker = filtered_speaker[1:]
1251
+
1252
+ dir_name_speaker = speaker + audio_name
1253
+ dir_name_speaker_tts = "tts" + speaker + audio_name
1254
+ dir_path_speaker = os.path.join(output_dir, dir_name_speaker)
1255
+ dir_path_speaker_tts = os.path.join(output_dir, dir_name_speaker_tts)
1256
+ create_directories([dir_path_speaker, dir_path_speaker_tts])
1257
+
1258
+ path_target_segments.append(dir_path_speaker)
1259
+ path_source_segments.append(dir_path_speaker_tts)
1260
+
1261
+ # create wav
1262
+ max_segments_count = 0
1263
+ for seg in filtered_speaker:
1264
+ duration = float(seg["end"]) - float(seg["start"])
1265
+ if duration > 3.0 and duration < 18.0:
1266
+ logger.info(
1267
+ f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {duration}, {seg["text"]}'
1268
+ )
1269
+ name_new_wav = str(seg["start"])
1270
+
1271
+ check_segment_audio_target_file = os.path.join(
1272
+ dir_path_speaker, f"{name_new_wav}.wav"
1273
+ )
1274
+
1275
+ if os.path.exists(check_segment_audio_target_file):
1276
+ logger.debug(
1277
+ "Segment vc source exists: "
1278
+ f"{check_segment_audio_target_file}"
1279
+ )
1280
+ pass
1281
+ else:
1282
+ create_wav_file_vc(
1283
+ sample_name=name_new_wav,
1284
+ audio_wav="audio.wav",
1285
+ start=(float(seg["start"]) + 1.0),
1286
+ end=(float(seg["end"]) - 1.0),
1287
+ output_final_path=dir_path_speaker,
1288
+ get_vocals_dereverb=get_vocals_dereverb,
1289
+ )
1290
+
1291
+ file_name_tts = f"audio2/audio/{str(seg['start'])}.ogg"
1292
+ # copy_files(file_name_tts, os.path.join(output_dir, dir_name_speaker_tts)
1293
+ convert_to_xtts_good_sample(
1294
+ file_name_tts, dir_path_speaker_tts
1295
+ )
1296
+
1297
+ max_segments_count += 1
1298
+ if max_segments_count == max_segments:
1299
+ break
1300
+
1301
+ if max_segments_count == 0:
1302
+ logger.info("Taking the first segment")
1303
+ seg = filtered_speaker[0]
1304
+ logger.info(
1305
+ f'Processing segment: {seg["start"]}, {seg["end"]}, {seg["speaker"]}, {seg["text"]}'
1306
+ )
1307
+ max_duration = float(seg["end"]) - float(seg["start"])
1308
+ max_duration = max(1.0, min(max_duration, 18.0))
1309
+
1310
+ name_new_wav = str(seg["start"])
1311
+ create_wav_file_vc(
1312
+ sample_name=name_new_wav,
1313
+ audio_wav="audio.wav",
1314
+ start=(float(seg["start"])),
1315
+ end=(float(seg["start"]) + max_duration),
1316
+ output_final_path=dir_path_speaker,
1317
+ get_vocals_dereverb=get_vocals_dereverb,
1318
+ )
1319
+
1320
+ file_name_tts = f"audio2/audio/{str(seg['start'])}.ogg"
1321
+ # copy_files(file_name_tts, os.path.join(output_dir, dir_name_speaker_tts)
1322
+ convert_to_xtts_good_sample(file_name_tts, dir_path_speaker_tts)
1323
+
1324
+ logger.debug(f"Base: {str(path_source_segments)}")
1325
+ logger.debug(f"Target: {str(path_target_segments)}")
1326
+
1327
+ return path_source_segments, path_target_segments
1328
+
1329
+
1330
+ def toneconverter_openvoice(
1331
+ result_diarize,
1332
+ preprocessor_max_segments,
1333
+ remove_previous_process=True,
1334
+ get_vocals_dereverb=False,
1335
+ model="openvoice",
1336
+ ):
1337
+ audio_path = "audio.wav"
1338
+ # se_path = "se.pth"
1339
+ target_dir = "processed"
1340
+ create_directories(target_dir)
1341
+
1342
+ from openvoice import se_extractor
1343
+ from openvoice.api import ToneColorConverter
1344
+
1345
+ audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{se_extractor.hash_numpy_array(audio_path)}"
1346
+ # se_path = os.path.join(target_dir, audio_name, 'se.pth')
1347
+
1348
+ # create wav seg original and target
1349
+
1350
+ valid_speakers = list(
1351
+ {item["speaker"] for item in result_diarize["segments"]}
1352
+ )
1353
+
1354
+ logger.info("Openvoice preprocessor...")
1355
+
1356
+ if remove_previous_process:
1357
+ remove_directory_contents(target_dir)
1358
+
1359
+ path_source_segments, path_target_segments = create_wav_vc(
1360
+ valid_speakers,
1361
+ result_diarize["segments"],
1362
+ audio_name,
1363
+ max_segments=preprocessor_max_segments,
1364
+ get_vocals_dereverb=get_vocals_dereverb,
1365
+ )
1366
+
1367
+ logger.info("Openvoice loading model...")
1368
+ model_path_openvoice = "./OPENVOICE_MODELS"
1369
+ url_model_openvoice = "https://huggingface.co/myshell-ai/OpenVoice/resolve/main/checkpoints/converter"
1370
+
1371
+ if "v2" in model:
1372
+ model_path = os.path.join(model_path_openvoice, "v2")
1373
+ url_model_openvoice = url_model_openvoice.replace(
1374
+ "OpenVoice", "OpenVoiceV2"
1375
+ ).replace("checkpoints/", "")
1376
+ else:
1377
+ model_path = os.path.join(model_path_openvoice, "v1")
1378
+ create_directories(model_path)
1379
+
1380
+ config_url = f"{url_model_openvoice}/config.json"
1381
+ checkpoint_url = f"{url_model_openvoice}/checkpoint.pth"
1382
+
1383
+ config_path = download_manager(url=config_url, path=model_path)
1384
+ checkpoint_path = download_manager(
1385
+ url=checkpoint_url, path=model_path
1386
+ )
1387
+
1388
+ device = os.environ.get("SONITR_DEVICE")
1389
+ tone_color_converter = ToneColorConverter(config_path, device=device)
1390
+ tone_color_converter.load_ckpt(checkpoint_path)
1391
+
1392
+ logger.info("Openvoice tone color converter:")
1393
+ global_progress_bar = tqdm(total=len(result_diarize["segments"]), desc="Progress")
1394
+
1395
+ for source_seg, target_seg, speaker in zip(
1396
+ path_source_segments, path_target_segments, valid_speakers
1397
+ ):
1398
+ # source_se_path = os.path.join(source_seg, 'se.pth')
1399
+ source_se = se_process_audio_segments(source_seg, tone_color_converter, device)
1400
+ # target_se_path = os.path.join(target_seg, 'se.pth')
1401
+ target_se = se_process_audio_segments(target_seg, tone_color_converter, device)
1402
+
1403
+ # Iterate throw segments
1404
+ encode_message = "@MyShell"
1405
+ filtered_speaker = [
1406
+ segment
1407
+ for segment in result_diarize["segments"]
1408
+ if segment["speaker"] == speaker
1409
+ ]
1410
+ for seg in filtered_speaker:
1411
+ src_path = (
1412
+ save_path
1413
+ ) = f"audio2/audio/{str(seg['start'])}.ogg" # overwrite
1414
+ logger.debug(f"{src_path}")
1415
+
1416
+ tone_color_converter.convert(
1417
+ audio_src_path=src_path,
1418
+ src_se=source_se,
1419
+ tgt_se=target_se,
1420
+ output_path=save_path,
1421
+ message=encode_message,
1422
+ )
1423
+
1424
+ global_progress_bar.update(1)
1425
+
1426
+ global_progress_bar.close()
1427
+
1428
+ try:
1429
+ del tone_color_converter
1430
+ gc.collect()
1431
+ torch.cuda.empty_cache()
1432
+ except Exception as error:
1433
+ logger.error(str(error))
1434
+ gc.collect()
1435
+ torch.cuda.empty_cache()
1436
+
1437
+
1438
+ def toneconverter_freevc(
1439
+ result_diarize,
1440
+ remove_previous_process=True,
1441
+ get_vocals_dereverb=False,
1442
+ ):
1443
+ audio_path = "audio.wav"
1444
+ target_dir = "processed"
1445
+ create_directories(target_dir)
1446
+
1447
+ from openvoice import se_extractor
1448
+
1449
+ audio_name = f"{os.path.basename(audio_path).rsplit('.', 1)[0]}_{se_extractor.hash_numpy_array(audio_path)}"
1450
+
1451
+ # create wav seg; original is target and dubbing is source
1452
+ valid_speakers = list(
1453
+ {item["speaker"] for item in result_diarize["segments"]}
1454
+ )
1455
+
1456
+ logger.info("FreeVC preprocessor...")
1457
+
1458
+ if remove_previous_process:
1459
+ remove_directory_contents(target_dir)
1460
+
1461
+ path_source_segments, path_target_segments = create_wav_vc(
1462
+ valid_speakers,
1463
+ result_diarize["segments"],
1464
+ audio_name,
1465
+ max_segments=1,
1466
+ get_vocals_dereverb=get_vocals_dereverb,
1467
+ )
1468
+
1469
+ logger.info("FreeVC loading model...")
1470
+ device_id = os.environ.get("SONITR_DEVICE")
1471
+ device = None if device_id == "cpu" else device_id
1472
+ try:
1473
+ from TTS.api import TTS
1474
+ tts = TTS(
1475
+ model_name="voice_conversion_models/multilingual/vctk/freevc24",
1476
+ progress_bar=False
1477
+ ).to(device)
1478
+ except Exception as error:
1479
+ logger.error(str(error))
1480
+ logger.error("Error loading the FreeVC model.")
1481
+ return
1482
+
1483
+ logger.info("FreeVC process:")
1484
+ global_progress_bar = tqdm(total=len(result_diarize["segments"]), desc="Progress")
1485
+
1486
+ for source_seg, target_seg, speaker in zip(
1487
+ path_source_segments, path_target_segments, valid_speakers
1488
+ ):
1489
+
1490
+ filtered_speaker = [
1491
+ segment
1492
+ for segment in result_diarize["segments"]
1493
+ if segment["speaker"] == speaker
1494
+ ]
1495
+
1496
+ files_and_directories = os.listdir(target_seg)
1497
+ wav_files = [file for file in files_and_directories if file.endswith(".wav")]
1498
+ original_wav_audio_segment = os.path.join(target_seg, wav_files[0])
1499
+
1500
+ for seg in filtered_speaker:
1501
+
1502
+ src_path = (
1503
+ save_path
1504
+ ) = f"audio2/audio/{str(seg['start'])}.ogg" # overwrite
1505
+ logger.debug(f"{src_path} - {original_wav_audio_segment}")
1506
+
1507
+ wav = tts.voice_conversion(
1508
+ source_wav=src_path,
1509
+ target_wav=original_wav_audio_segment,
1510
+ )
1511
+
1512
+ sf.write(
1513
+ file=save_path,
1514
+ samplerate=tts.voice_converter.vc_config.audio.output_sample_rate,
1515
+ data=wav,
1516
+ format="ogg",
1517
+ subtype="vorbis",
1518
+ )
1519
+
1520
+ global_progress_bar.update(1)
1521
+
1522
+ global_progress_bar.close()
1523
+
1524
+ try:
1525
+ del tts
1526
+ gc.collect()
1527
+ torch.cuda.empty_cache()
1528
+ except Exception as error:
1529
+ logger.error(str(error))
1530
+ gc.collect()
1531
+ torch.cuda.empty_cache()
1532
+
1533
+
1534
+ def toneconverter(
1535
+ result_diarize,
1536
+ preprocessor_max_segments,
1537
+ remove_previous_process=True,
1538
+ get_vocals_dereverb=False,
1539
+ method_vc="freevc"
1540
+ ):
1541
+
1542
+ if method_vc == "freevc":
1543
+ if preprocessor_max_segments > 1:
1544
+ logger.info("FreeVC only uses one segment.")
1545
+ return toneconverter_freevc(
1546
+ result_diarize,
1547
+ remove_previous_process=remove_previous_process,
1548
+ get_vocals_dereverb=get_vocals_dereverb,
1549
+ )
1550
+ elif "openvoice" in method_vc:
1551
+ return toneconverter_openvoice(
1552
+ result_diarize,
1553
+ preprocessor_max_segments,
1554
+ remove_previous_process=remove_previous_process,
1555
+ get_vocals_dereverb=get_vocals_dereverb,
1556
+ model=method_vc,
1557
+ )
1558
+
1559
+
1560
+ if __name__ == "__main__":
1561
+ from segments import result_diarize
1562
+
1563
+ audio_segmentation_to_voice(
1564
+ result_diarize,
1565
+ TRANSLATE_AUDIO_TO="en",
1566
+ max_accelerate_audio=2.1,
1567
+ is_gui=True,
1568
+ tts_voice00="en-facebook-mms VITS",
1569
+ tts_voice01="en-CA-ClaraNeural-Female",
1570
+ tts_voice02="en-GB-ThomasNeural-Male",
1571
+ tts_voice03="en-GB-SoniaNeural-Female",
1572
+ tts_voice04="en-NZ-MitchellNeural-Male",
1573
+ tts_voice05="en-GB-MaisieNeural-Female",
1574
+ )
soni_translate/translate_segments.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ from deep_translator import GoogleTranslator
3
+ from itertools import chain
4
+ import copy
5
+ from .language_configuration import fix_code_language, INVERTED_LANGUAGES
6
+ from .logging_setup import logger
7
+ import re
8
+ import json
9
+ import time
10
+
11
+ TRANSLATION_PROCESS_OPTIONS = [
12
+ "google_translator_batch",
13
+ "google_translator",
14
+ "gpt-3.5-turbo-0125_batch",
15
+ "gpt-3.5-turbo-0125",
16
+ "gpt-4-turbo-preview_batch",
17
+ "gpt-4-turbo-preview",
18
+ "disable_translation",
19
+ ]
20
+ DOCS_TRANSLATION_PROCESS_OPTIONS = [
21
+ "google_translator",
22
+ "gpt-3.5-turbo-0125",
23
+ "gpt-4-turbo-preview",
24
+ "disable_translation",
25
+ ]
26
+
27
+
28
+ def translate_iterative(segments, target, source=None):
29
+ """
30
+ Translate text segments individually to the specified language.
31
+
32
+ Parameters:
33
+ - segments (list): A list of dictionaries with 'text' as a key for
34
+ segment text.
35
+ - target (str): Target language code.
36
+ - source (str, optional): Source language code. Defaults to None.
37
+
38
+ Returns:
39
+ - list: Translated text segments in the target language.
40
+
41
+ Notes:
42
+ - Translates each segment using Google Translate.
43
+
44
+ Example:
45
+ segments = [{'text': 'first segment.'}, {'text': 'second segment.'}]
46
+ translated_segments = translate_iterative(segments, 'es')
47
+ """
48
+
49
+ segments_ = copy.deepcopy(segments)
50
+
51
+ if (
52
+ not source
53
+ ):
54
+ logger.debug("No source language")
55
+ source = "auto"
56
+
57
+ translator = GoogleTranslator(source=source, target=target)
58
+
59
+ for line in tqdm(range(len(segments_))):
60
+ text = segments_[line]["text"]
61
+ translated_line = translator.translate(text.strip())
62
+ segments_[line]["text"] = translated_line
63
+
64
+ return segments_
65
+
66
+
67
+ def verify_translate(
68
+ segments,
69
+ segments_copy,
70
+ translated_lines,
71
+ target,
72
+ source
73
+ ):
74
+ """
75
+ Verify integrity and translate segments if lengths match, otherwise
76
+ switch to iterative translation.
77
+ """
78
+ if len(segments) == len(translated_lines):
79
+ for line in range(len(segments_copy)):
80
+ logger.debug(
81
+ f"{segments_copy[line]['text']} >> "
82
+ f"{translated_lines[line].strip()}"
83
+ )
84
+ segments_copy[line]["text"] = translated_lines[
85
+ line].replace("\t", "").replace("\n", "").strip()
86
+ return segments_copy
87
+ else:
88
+ logger.error(
89
+ "The translation failed, switching to google_translate iterative. "
90
+ f"{len(segments), len(translated_lines)}"
91
+ )
92
+ return translate_iterative(segments, target, source)
93
+
94
+
95
+ def translate_batch(segments, target, chunk_size=2000, source=None):
96
+ """
97
+ Translate a batch of text segments into the specified language in chunks,
98
+ respecting the character limit.
99
+
100
+ Parameters:
101
+ - segments (list): List of dictionaries with 'text' as a key for segment
102
+ text.
103
+ - target (str): Target language code.
104
+ - chunk_size (int, optional): Maximum character limit for each translation
105
+ chunk (default is 2000; max 5000).
106
+ - source (str, optional): Source language code. Defaults to None.
107
+
108
+ Returns:
109
+ - list: Translated text segments in the target language.
110
+
111
+ Notes:
112
+ - Splits input segments into chunks respecting the character limit for
113
+ translation.
114
+ - Translates the chunks using Google Translate.
115
+ - If chunked translation fails, switches to iterative translation using
116
+ `translate_iterative()`.
117
+
118
+ Example:
119
+ segments = [{'text': 'first segment.'}, {'text': 'second segment.'}]
120
+ translated = translate_batch(segments, 'es', chunk_size=4000, source='en')
121
+ """
122
+
123
+ segments_copy = copy.deepcopy(segments)
124
+
125
+ if (
126
+ not source
127
+ ):
128
+ logger.debug("No source language")
129
+ source = "auto"
130
+
131
+ # Get text
132
+ text_lines = []
133
+ for line in range(len(segments_copy)):
134
+ text = segments_copy[line]["text"].strip()
135
+ text_lines.append(text)
136
+
137
+ # chunk limit
138
+ text_merge = []
139
+ actual_chunk = ""
140
+ global_text_list = []
141
+ actual_text_list = []
142
+ for one_line in text_lines:
143
+ one_line = " " if not one_line else one_line
144
+ if (len(actual_chunk) + len(one_line)) <= chunk_size:
145
+ if actual_chunk:
146
+ actual_chunk += " ||||| "
147
+ actual_chunk += one_line
148
+ actual_text_list.append(one_line)
149
+ else:
150
+ text_merge.append(actual_chunk)
151
+ actual_chunk = one_line
152
+ global_text_list.append(actual_text_list)
153
+ actual_text_list = [one_line]
154
+ if actual_chunk:
155
+ text_merge.append(actual_chunk)
156
+ global_text_list.append(actual_text_list)
157
+
158
+ # translate chunks
159
+ progress_bar = tqdm(total=len(segments), desc="Translating")
160
+ translator = GoogleTranslator(source=source, target=target)
161
+ split_list = []
162
+ try:
163
+ for text, text_iterable in zip(text_merge, global_text_list):
164
+ translated_line = translator.translate(text.strip())
165
+ split_text = translated_line.split("|||||")
166
+ if len(split_text) == len(text_iterable):
167
+ progress_bar.update(len(split_text))
168
+ else:
169
+ logger.debug(
170
+ "Chunk fixing iteratively. Len chunk: "
171
+ f"{len(split_text)}, expected: {len(text_iterable)}"
172
+ )
173
+ split_text = []
174
+ for txt_iter in text_iterable:
175
+ translated_txt = translator.translate(txt_iter.strip())
176
+ split_text.append(translated_txt)
177
+ progress_bar.update(1)
178
+ split_list.append(split_text)
179
+ progress_bar.close()
180
+ except Exception as error:
181
+ progress_bar.close()
182
+ logger.error(str(error))
183
+ logger.warning(
184
+ "The translation in chunks failed, switching to iterative."
185
+ " Related: too many request"
186
+ ) # use proxy or less chunk size
187
+ return translate_iterative(segments, target, source)
188
+
189
+ # un chunk
190
+ translated_lines = list(chain.from_iterable(split_list))
191
+
192
+ return verify_translate(
193
+ segments, segments_copy, translated_lines, target, source
194
+ )
195
+
196
+
197
+ def call_gpt_translate(
198
+ client,
199
+ model,
200
+ system_prompt,
201
+ user_prompt,
202
+ original_text=None,
203
+ batch_lines=None,
204
+ ):
205
+
206
+ # https://platform.openai.com/docs/guides/text-generation/json-mode
207
+ response = client.chat.completions.create(
208
+ model=model,
209
+ response_format={"type": "json_object"},
210
+ messages=[
211
+ {"role": "system", "content": system_prompt},
212
+ {"role": "user", "content": user_prompt}
213
+ ]
214
+ )
215
+ result = response.choices[0].message.content
216
+ logger.debug(f"Result: {str(result)}")
217
+
218
+ try:
219
+ translation = json.loads(result)
220
+ except Exception as error:
221
+ match_result = re.search(r'\{.*?\}', result)
222
+ if match_result:
223
+ logger.error(str(error))
224
+ json_str = match_result.group(0)
225
+ translation = json.loads(json_str)
226
+ else:
227
+ raise error
228
+
229
+ # Get valid data
230
+ if batch_lines:
231
+ for conversation in translation.values():
232
+ if isinstance(conversation, dict):
233
+ conversation = list(conversation.values())[0]
234
+ if (
235
+ list(
236
+ original_text["conversation"][0].values()
237
+ )[0].strip() ==
238
+ list(conversation[0].values())[0].strip()
239
+ ):
240
+ continue
241
+ if len(conversation) == batch_lines:
242
+ break
243
+
244
+ fix_conversation_length = []
245
+ for line in conversation:
246
+ for speaker_code, text_tr in line.items():
247
+ fix_conversation_length.append({speaker_code: text_tr})
248
+
249
+ logger.debug(f"Data batch: {str(fix_conversation_length)}")
250
+ logger.debug(
251
+ f"Lines Received: {len(fix_conversation_length)},"
252
+ f" expected: {batch_lines}"
253
+ )
254
+
255
+ return fix_conversation_length
256
+
257
+ else:
258
+ if isinstance(translation, dict):
259
+ translation = list(translation.values())[0]
260
+ if isinstance(translation, list):
261
+ translation = translation[0]
262
+ if isinstance(translation, set):
263
+ translation = list(translation)[0]
264
+ if not isinstance(translation, str):
265
+ raise ValueError(f"No valid response received: {str(translation)}")
266
+
267
+ return translation
268
+
269
+
270
+ def gpt_sequential(segments, model, target, source=None):
271
+ from openai import OpenAI
272
+
273
+ translated_segments = copy.deepcopy(segments)
274
+
275
+ client = OpenAI()
276
+ progress_bar = tqdm(total=len(segments), desc="Translating")
277
+
278
+ lang_tg = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[target]).strip()
279
+ lang_sc = ""
280
+ if source:
281
+ lang_sc = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[source]).strip()
282
+
283
+ fixed_target = fix_code_language(target)
284
+ fixed_source = fix_code_language(source) if source else "auto"
285
+
286
+ system_prompt = "Machine translation designed to output the translated_text JSON."
287
+
288
+ for i, line in enumerate(translated_segments):
289
+ text = line["text"].strip()
290
+ start = line["start"]
291
+ user_prompt = f"Translate the following {lang_sc} text into {lang_tg}, write the fully translated text and nothing more:\n{text}"
292
+
293
+ time.sleep(0.5)
294
+
295
+ try:
296
+ translated_text = call_gpt_translate(
297
+ client,
298
+ model,
299
+ system_prompt,
300
+ user_prompt,
301
+ )
302
+
303
+ except Exception as error:
304
+ logger.error(
305
+ f"{str(error)} >> The text of segment {start} "
306
+ "is being corrected with Google Translate"
307
+ )
308
+ translator = GoogleTranslator(
309
+ source=fixed_source, target=fixed_target
310
+ )
311
+ translated_text = translator.translate(text.strip())
312
+
313
+ translated_segments[i]["text"] = translated_text.strip()
314
+ progress_bar.update(1)
315
+
316
+ progress_bar.close()
317
+
318
+ return translated_segments
319
+
320
+
321
+ def gpt_batch(segments, model, target, token_batch_limit=900, source=None):
322
+ from openai import OpenAI
323
+ import tiktoken
324
+
325
+ token_batch_limit = max(100, (token_batch_limit - 40) // 2)
326
+ progress_bar = tqdm(total=len(segments), desc="Translating")
327
+ segments_copy = copy.deepcopy(segments)
328
+ encoding = tiktoken.get_encoding("cl100k_base")
329
+ client = OpenAI()
330
+
331
+ lang_tg = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[target]).strip()
332
+ lang_sc = ""
333
+ if source:
334
+ lang_sc = re.sub(r'\([^)]*\)', '', INVERTED_LANGUAGES[source]).strip()
335
+
336
+ fixed_target = fix_code_language(target)
337
+ fixed_source = fix_code_language(source) if source else "auto"
338
+
339
+ name_speaker = "ABCDEFGHIJKL"
340
+
341
+ translated_lines = []
342
+ text_data_dict = []
343
+ num_tokens = 0
344
+ count_sk = {char: 0 for char in "ABCDEFGHIJKL"}
345
+
346
+ for i, line in enumerate(segments_copy):
347
+ text = line["text"]
348
+ speaker = line["speaker"]
349
+ last_start = line["start"]
350
+ # text_data_dict.append({str(int(speaker[-1])+1): text})
351
+ index_sk = int(speaker[-2:])
352
+ character_sk = name_speaker[index_sk]
353
+ count_sk[character_sk] += 1
354
+ code_sk = character_sk+str(count_sk[character_sk])
355
+ text_data_dict.append({code_sk: text})
356
+ num_tokens += len(encoding.encode(text)) + 7
357
+ if num_tokens >= token_batch_limit or i == len(segments_copy)-1:
358
+ try:
359
+ batch_lines = len(text_data_dict)
360
+ batch_conversation = {"conversation": copy.deepcopy(text_data_dict)}
361
+ # Reset vars
362
+ num_tokens = 0
363
+ text_data_dict = []
364
+ count_sk = {char: 0 for char in "ABCDEFGHIJKL"}
365
+ # Process translation
366
+ # https://arxiv.org/pdf/2309.03409.pdf
367
+ system_prompt = f"Machine translation designed to output the translated_conversation key JSON containing a list of {batch_lines} items."
368
+ user_prompt = f"Translate each of the following text values in conversation{' from' if lang_sc else ''} {lang_sc} to {lang_tg}:\n{batch_conversation}"
369
+ logger.debug(f"Prompt: {str(user_prompt)}")
370
+
371
+ conversation = call_gpt_translate(
372
+ client,
373
+ model,
374
+ system_prompt,
375
+ user_prompt,
376
+ original_text=batch_conversation,
377
+ batch_lines=batch_lines,
378
+ )
379
+
380
+ if len(conversation) < batch_lines:
381
+ raise ValueError(
382
+ "Incomplete result received. Batch lines: "
383
+ f"{len(conversation)}, expected: {batch_lines}"
384
+ )
385
+
386
+ for i, translated_text in enumerate(conversation):
387
+ if i+1 > batch_lines:
388
+ break
389
+ translated_lines.append(list(translated_text.values())[0])
390
+
391
+ progress_bar.update(batch_lines)
392
+
393
+ except Exception as error:
394
+ logger.error(str(error))
395
+
396
+ first_start = segments_copy[max(0, i-(batch_lines-1))]["start"]
397
+ logger.warning(
398
+ f"The batch from {first_start} to {last_start} "
399
+ "failed, is being corrected with Google Translate"
400
+ )
401
+
402
+ translator = GoogleTranslator(
403
+ source=fixed_source,
404
+ target=fixed_target
405
+ )
406
+
407
+ for txt_source in batch_conversation["conversation"]:
408
+ translated_txt = translator.translate(
409
+ list(txt_source.values())[0].strip()
410
+ )
411
+ translated_lines.append(translated_txt.strip())
412
+ progress_bar.update(1)
413
+
414
+ progress_bar.close()
415
+
416
+ return verify_translate(
417
+ segments, segments_copy, translated_lines, fixed_target, fixed_source
418
+ )
419
+
420
+
421
+ def translate_text(
422
+ segments,
423
+ target,
424
+ translation_process="google_translator_batch",
425
+ chunk_size=4500,
426
+ source=None,
427
+ token_batch_limit=1000,
428
+ ):
429
+ """Translates text segments using a specified process."""
430
+ match translation_process:
431
+ case "google_translator_batch":
432
+ return translate_batch(
433
+ segments,
434
+ fix_code_language(target),
435
+ chunk_size,
436
+ fix_code_language(source)
437
+ )
438
+ case "google_translator":
439
+ return translate_iterative(
440
+ segments,
441
+ fix_code_language(target),
442
+ fix_code_language(source)
443
+ )
444
+ case model if model in ["gpt-3.5-turbo-0125", "gpt-4-turbo-preview"]:
445
+ return gpt_sequential(segments, model, target, source)
446
+ case model if model in ["gpt-3.5-turbo-0125_batch", "gpt-4-turbo-preview_batch",]:
447
+ return gpt_batch(
448
+ segments,
449
+ translation_process.replace("_batch", ""),
450
+ target,
451
+ token_batch_limit,
452
+ source
453
+ )
454
+ case "disable_translation":
455
+ return segments
456
+ case _:
457
+ raise ValueError("No valid translation process")
soni_translate/utils.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, zipfile, rarfile, shutil, subprocess, shlex, sys # noqa
2
+ from .logging_setup import logger
3
+ from urllib.parse import urlparse
4
+ from IPython.utils import capture
5
+ import re
6
+
7
+ VIDEO_EXTENSIONS = [
8
+ ".mp4",
9
+ ".avi",
10
+ ".mov",
11
+ ".mkv",
12
+ ".wmv",
13
+ ".flv",
14
+ ".webm",
15
+ ".m4v",
16
+ ".mpeg",
17
+ ".mpg",
18
+ ".3gp"
19
+ ]
20
+
21
+ AUDIO_EXTENSIONS = [
22
+ ".mp3",
23
+ ".wav",
24
+ ".aiff",
25
+ ".aif",
26
+ ".flac",
27
+ ".aac",
28
+ ".ogg",
29
+ ".wma",
30
+ ".m4a",
31
+ ".alac",
32
+ ".pcm",
33
+ ".opus",
34
+ ".ape",
35
+ ".amr",
36
+ ".ac3",
37
+ ".vox",
38
+ ".caf"
39
+ ]
40
+
41
+ SUBTITLE_EXTENSIONS = [
42
+ ".srt",
43
+ ".vtt",
44
+ ".ass"
45
+ ]
46
+
47
+
48
+ def run_command(command):
49
+ logger.debug(command)
50
+ if isinstance(command, str):
51
+ command = shlex.split(command)
52
+
53
+ sub_params = {
54
+ "stdout": subprocess.PIPE,
55
+ "stderr": subprocess.PIPE,
56
+ "creationflags": subprocess.CREATE_NO_WINDOW
57
+ if sys.platform == "win32"
58
+ else 0,
59
+ }
60
+ process_command = subprocess.Popen(command, **sub_params)
61
+ output, errors = process_command.communicate()
62
+ if (
63
+ process_command.returncode != 0
64
+ ): # or not os.path.exists(mono_path) or os.path.getsize(mono_path) == 0:
65
+ logger.error("Error comnand")
66
+ raise Exception(errors.decode())
67
+
68
+
69
+ def print_tree_directory(root_dir, indent=""):
70
+ if not os.path.exists(root_dir):
71
+ logger.error(f"{indent} Invalid directory or file: {root_dir}")
72
+ return
73
+
74
+ items = os.listdir(root_dir)
75
+
76
+ for index, item in enumerate(sorted(items)):
77
+ item_path = os.path.join(root_dir, item)
78
+ is_last_item = index == len(items) - 1
79
+
80
+ if os.path.isfile(item_path) and item_path.endswith(".zip"):
81
+ with zipfile.ZipFile(item_path, "r") as zip_file:
82
+ print(
83
+ f"{indent}{'└──' if is_last_item else '├──'} {item} (zip file)"
84
+ )
85
+ zip_contents = zip_file.namelist()
86
+ for zip_item in sorted(zip_contents):
87
+ print(
88
+ f"{indent}{' ' if is_last_item else '│ '}{zip_item}"
89
+ )
90
+ else:
91
+ print(f"{indent}{'└──' if is_last_item else '├──'} {item}")
92
+
93
+ if os.path.isdir(item_path):
94
+ new_indent = indent + (" " if is_last_item else "│ ")
95
+ print_tree_directory(item_path, new_indent)
96
+
97
+
98
+ def upload_model_list():
99
+ weight_root = "weights"
100
+ models = []
101
+ for name in os.listdir(weight_root):
102
+ if name.endswith(".pth"):
103
+ models.append("weights/" + name)
104
+ if models:
105
+ logger.debug(models)
106
+
107
+ index_root = "logs"
108
+ index_paths = [None]
109
+ for name in os.listdir(index_root):
110
+ if name.endswith(".index"):
111
+ index_paths.append("logs/" + name)
112
+ if index_paths:
113
+ logger.debug(index_paths)
114
+
115
+ return models, index_paths
116
+
117
+
118
+ def manual_download(url, dst):
119
+ if "drive.google" in url:
120
+ logger.info("Drive url")
121
+ if "folders" in url:
122
+ logger.info("folder")
123
+ os.system(f'gdown --folder "{url}" -O {dst} --fuzzy -c')
124
+ else:
125
+ logger.info("single")
126
+ os.system(f'gdown "{url}" -O {dst} --fuzzy -c')
127
+ elif "huggingface" in url:
128
+ logger.info("HuggingFace url")
129
+ if "/blob/" in url or "/resolve/" in url:
130
+ if "/blob/" in url:
131
+ url = url.replace("/blob/", "/resolve/")
132
+ download_manager(url=url, path=dst, overwrite=True, progress=True)
133
+ else:
134
+ os.system(f"git clone {url} {dst+'repo/'}")
135
+ elif "http" in url:
136
+ logger.info("URL")
137
+ download_manager(url=url, path=dst, overwrite=True, progress=True)
138
+ elif os.path.exists(url):
139
+ logger.info("Path")
140
+ copy_files(url, dst)
141
+ else:
142
+ logger.error(f"No valid URL: {url}")
143
+
144
+
145
+ def download_list(text_downloads):
146
+
147
+ if os.environ.get("ZERO_GPU") == "TRUE":
148
+ raise RuntimeError("This option is disabled in this demo.")
149
+
150
+ try:
151
+ urls = [elem.strip() for elem in text_downloads.split(",")]
152
+ except Exception as error:
153
+ raise ValueError(f"No valid URL. {str(error)}")
154
+
155
+ create_directories(["downloads", "logs", "weights"])
156
+
157
+ path_download = "downloads/"
158
+ for url in urls:
159
+ manual_download(url, path_download)
160
+
161
+ # Tree
162
+ print("####################################")
163
+ print_tree_directory("downloads", indent="")
164
+ print("####################################")
165
+
166
+ # Place files
167
+ select_zip_and_rar_files("downloads/")
168
+
169
+ models, _ = upload_model_list()
170
+
171
+ # hf space models files delete
172
+ remove_directory_contents("downloads/repo")
173
+
174
+ return f"Downloaded = {models}"
175
+
176
+
177
+ def select_zip_and_rar_files(directory_path="downloads/"):
178
+ # filter
179
+ zip_files = []
180
+ rar_files = []
181
+
182
+ for file_name in os.listdir(directory_path):
183
+ if file_name.endswith(".zip"):
184
+ zip_files.append(file_name)
185
+ elif file_name.endswith(".rar"):
186
+ rar_files.append(file_name)
187
+
188
+ # extract
189
+ for file_name in zip_files:
190
+ file_path = os.path.join(directory_path, file_name)
191
+ with zipfile.ZipFile(file_path, "r") as zip_ref:
192
+ zip_ref.extractall(directory_path)
193
+
194
+ for file_name in rar_files:
195
+ file_path = os.path.join(directory_path, file_name)
196
+ with rarfile.RarFile(file_path, "r") as rar_ref:
197
+ rar_ref.extractall(directory_path)
198
+
199
+ # set in path
200
+ def move_files_with_extension(src_dir, extension, destination_dir):
201
+ for root, _, files in os.walk(src_dir):
202
+ for file_name in files:
203
+ if file_name.endswith(extension):
204
+ source_file = os.path.join(root, file_name)
205
+ destination = os.path.join(destination_dir, file_name)
206
+ shutil.move(source_file, destination)
207
+
208
+ move_files_with_extension(directory_path, ".index", "logs/")
209
+ move_files_with_extension(directory_path, ".pth", "weights/")
210
+
211
+ return "Download complete"
212
+
213
+
214
+ def is_file_with_extensions(string_path, extensions):
215
+ return any(string_path.lower().endswith(ext) for ext in extensions)
216
+
217
+
218
+ def is_video_file(string_path):
219
+ return is_file_with_extensions(string_path, VIDEO_EXTENSIONS)
220
+
221
+
222
+ def is_audio_file(string_path):
223
+ return is_file_with_extensions(string_path, AUDIO_EXTENSIONS)
224
+
225
+
226
+ def is_subtitle_file(string_path):
227
+ return is_file_with_extensions(string_path, SUBTITLE_EXTENSIONS)
228
+
229
+
230
+ def get_directory_files(directory):
231
+ audio_files = []
232
+ video_files = []
233
+ sub_files = []
234
+
235
+ for item in os.listdir(directory):
236
+ item_path = os.path.join(directory, item)
237
+
238
+ if os.path.isfile(item_path):
239
+
240
+ if is_audio_file(item_path):
241
+ audio_files.append(item_path)
242
+
243
+ elif is_video_file(item_path):
244
+ video_files.append(item_path)
245
+
246
+ elif is_subtitle_file(item_path):
247
+ sub_files.append(item_path)
248
+
249
+ logger.info(
250
+ f"Files in path ({directory}): "
251
+ f"{str(audio_files + video_files + sub_files)}"
252
+ )
253
+
254
+ return audio_files, video_files, sub_files
255
+
256
+
257
+ def get_valid_files(paths):
258
+ valid_paths = []
259
+ for path in paths:
260
+ if os.path.isdir(path):
261
+ audio_files, video_files, sub_files = get_directory_files(path)
262
+ valid_paths.extend(audio_files)
263
+ valid_paths.extend(video_files)
264
+ valid_paths.extend(sub_files)
265
+ else:
266
+ valid_paths.append(path)
267
+
268
+ return valid_paths
269
+
270
+
271
+ def extract_video_links(link):
272
+
273
+ params_dlp = {"quiet": False, "no_warnings": True, "noplaylist": False}
274
+
275
+ try:
276
+ from yt_dlp import YoutubeDL
277
+ with capture.capture_output() as cap:
278
+ with YoutubeDL(params_dlp) as ydl:
279
+ info_dict = ydl.extract_info( # noqa
280
+ link, download=False, process=True
281
+ )
282
+
283
+ urls = re.findall(r'\[youtube\] Extracting URL: (.*?)\n', cap.stdout)
284
+ logger.info(f"List of videos in ({link}): {str(urls)}")
285
+ del cap
286
+ except Exception as error:
287
+ logger.error(f"{link} >> {str(error)}")
288
+ urls = [link]
289
+
290
+ return urls
291
+
292
+
293
+ def get_link_list(urls):
294
+ valid_links = []
295
+ for url_video in urls:
296
+ if "youtube.com" in url_video and "/watch?v=" not in url_video:
297
+ url_links = extract_video_links(url_video)
298
+ valid_links.extend(url_links)
299
+ else:
300
+ valid_links.append(url_video)
301
+ return valid_links
302
+
303
+ # =====================================
304
+ # Download Manager
305
+ # =====================================
306
+
307
+
308
+ def load_file_from_url(
309
+ url: str,
310
+ model_dir: str,
311
+ file_name: str | None = None,
312
+ overwrite: bool = False,
313
+ progress: bool = True,
314
+ ) -> str:
315
+ """Download a file from `url` into `model_dir`,
316
+ using the file present if possible.
317
+
318
+ Returns the path to the downloaded file.
319
+ """
320
+ os.makedirs(model_dir, exist_ok=True)
321
+ if not file_name:
322
+ parts = urlparse(url)
323
+ file_name = os.path.basename(parts.path)
324
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
325
+
326
+ # Overwrite
327
+ if os.path.exists(cached_file):
328
+ if overwrite or os.path.getsize(cached_file) == 0:
329
+ remove_files(cached_file)
330
+
331
+ # Download
332
+ if not os.path.exists(cached_file):
333
+ logger.info(f'Downloading: "{url}" to {cached_file}\n')
334
+ from torch.hub import download_url_to_file
335
+
336
+ download_url_to_file(url, cached_file, progress=progress)
337
+ else:
338
+ logger.debug(cached_file)
339
+
340
+ return cached_file
341
+
342
+
343
+ def friendly_name(file: str):
344
+ if file.startswith("http"):
345
+ file = urlparse(file).path
346
+
347
+ file = os.path.basename(file)
348
+ model_name, extension = os.path.splitext(file)
349
+ return model_name, extension
350
+
351
+
352
+ def download_manager(
353
+ url: str,
354
+ path: str,
355
+ extension: str = "",
356
+ overwrite: bool = False,
357
+ progress: bool = True,
358
+ ):
359
+ url = url.strip()
360
+
361
+ name, ext = friendly_name(url)
362
+ name += ext if not extension else f".{extension}"
363
+
364
+ if url.startswith("http"):
365
+ filename = load_file_from_url(
366
+ url=url,
367
+ model_dir=path,
368
+ file_name=name,
369
+ overwrite=overwrite,
370
+ progress=progress,
371
+ )
372
+ else:
373
+ filename = path
374
+
375
+ return filename
376
+
377
+
378
+ # =====================================
379
+ # File management
380
+ # =====================================
381
+
382
+
383
+ # only remove files
384
+ def remove_files(file_list):
385
+ if isinstance(file_list, str):
386
+ file_list = [file_list]
387
+
388
+ for file in file_list:
389
+ if os.path.exists(file):
390
+ os.remove(file)
391
+
392
+
393
+ def remove_directory_contents(directory_path):
394
+ """
395
+ Removes all files and subdirectories within a directory.
396
+
397
+ Parameters:
398
+ directory_path (str): Path to the directory whose
399
+ contents need to be removed.
400
+ """
401
+ if os.path.exists(directory_path):
402
+ for filename in os.listdir(directory_path):
403
+ file_path = os.path.join(directory_path, filename)
404
+ try:
405
+ if os.path.isfile(file_path):
406
+ os.remove(file_path)
407
+ elif os.path.isdir(file_path):
408
+ shutil.rmtree(file_path)
409
+ except Exception as e:
410
+ logger.error(f"Failed to delete {file_path}. Reason: {e}")
411
+ logger.info(f"Content in '{directory_path}' removed.")
412
+ else:
413
+ logger.error(f"Directory '{directory_path}' does not exist.")
414
+
415
+
416
+ # Create directory if not exists
417
+ def create_directories(directory_path):
418
+ if isinstance(directory_path, str):
419
+ directory_path = [directory_path]
420
+ for one_dir_path in directory_path:
421
+ if not os.path.exists(one_dir_path):
422
+ os.makedirs(one_dir_path)
423
+ logger.debug(f"Directory '{one_dir_path}' created.")
424
+
425
+
426
+ def move_files(source_dir, destination_dir, extension=""):
427
+ """
428
+ Moves file(s) from the source path to the destination path.
429
+
430
+ Parameters:
431
+ source_dir (str): Path to the source directory.
432
+ destination_dir (str): Path to the destination directory.
433
+ extension (str): Only move files with this extension.
434
+ """
435
+ create_directories(destination_dir)
436
+
437
+ for filename in os.listdir(source_dir):
438
+ source_path = os.path.join(source_dir, filename)
439
+ destination_path = os.path.join(destination_dir, filename)
440
+ if extension and not filename.endswith(extension):
441
+ continue
442
+ os.replace(source_path, destination_path)
443
+
444
+
445
+ def copy_files(source_path, destination_path):
446
+ """
447
+ Copies a file or multiple files from a source path to a destination path.
448
+
449
+ Parameters:
450
+ source_path (str or list): Path or list of paths to the source
451
+ file(s) or directory.
452
+ destination_path (str): Path to the destination directory.
453
+ """
454
+ create_directories(destination_path)
455
+
456
+ if isinstance(source_path, str):
457
+ source_path = [source_path]
458
+
459
+ if os.path.isdir(source_path[0]):
460
+ # Copy all files from the source directory to the destination directory
461
+ base_path = source_path[0]
462
+ source_path = os.listdir(source_path[0])
463
+ source_path = [
464
+ os.path.join(base_path, file_name) for file_name in source_path
465
+ ]
466
+
467
+ for one_source_path in source_path:
468
+ if os.path.exists(one_source_path):
469
+ shutil.copy2(one_source_path, destination_path)
470
+ logger.debug(
471
+ f"File '{one_source_path}' copied to '{destination_path}'."
472
+ )
473
+ else:
474
+ logger.error(f"File '{one_source_path}' does not exist.")
475
+
476
+
477
+ def rename_file(current_name, new_name):
478
+ file_directory = os.path.dirname(current_name)
479
+
480
+ if os.path.exists(current_name):
481
+ dir_new_name_file = os.path.join(file_directory, new_name)
482
+ os.rename(current_name, dir_new_name_file)
483
+ logger.debug(f"File '{current_name}' renamed to '{new_name}'.")
484
+ return dir_new_name_file
485
+ else:
486
+ logger.error(f"File '{current_name}' does not exist.")
487
+ return None
vci_pipeline.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np, parselmouth, torch, pdb, sys
2
+ from time import time as ttime
3
+ import torch.nn.functional as F
4
+ import scipy.signal as signal
5
+ import pyworld, os, traceback, faiss, librosa, torchcrepe
6
+ from scipy import signal
7
+ from functools import lru_cache
8
+ from soni_translate.logging_setup import logger
9
+
10
+ now_dir = os.getcwd()
11
+ sys.path.append(now_dir)
12
+
13
+ bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
14
+
15
+ input_audio_path2wav = {}
16
+
17
+
18
+ @lru_cache
19
+ def cache_harvest_f0(input_audio_path, fs, f0max, f0min, frame_period):
20
+ audio = input_audio_path2wav[input_audio_path]
21
+ f0, t = pyworld.harvest(
22
+ audio,
23
+ fs=fs,
24
+ f0_ceil=f0max,
25
+ f0_floor=f0min,
26
+ frame_period=frame_period,
27
+ )
28
+ f0 = pyworld.stonemask(audio, f0, t, fs)
29
+ return f0
30
+
31
+
32
+ def change_rms(data1, sr1, data2, sr2, rate): # 1 is the input audio, 2 is the output audio, rate is the proportion of 2
33
+ # print(data1.max(),data2.max())
34
+ rms1 = librosa.feature.rms(
35
+ y=data1, frame_length=sr1 // 2 * 2, hop_length=sr1 // 2
36
+ ) # one dot every half second
37
+ rms2 = librosa.feature.rms(y=data2, frame_length=sr2 // 2 * 2, hop_length=sr2 // 2)
38
+ rms1 = torch.from_numpy(rms1)
39
+ rms1 = F.interpolate(
40
+ rms1.unsqueeze(0), size=data2.shape[0], mode="linear"
41
+ ).squeeze()
42
+ rms2 = torch.from_numpy(rms2)
43
+ rms2 = F.interpolate(
44
+ rms2.unsqueeze(0), size=data2.shape[0], mode="linear"
45
+ ).squeeze()
46
+ rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-6)
47
+ data2 *= (
48
+ torch.pow(rms1, torch.tensor(1 - rate))
49
+ * torch.pow(rms2, torch.tensor(rate - 1))
50
+ ).numpy()
51
+ return data2
52
+
53
+
54
+ class VC(object):
55
+ def __init__(self, tgt_sr, config):
56
+ self.x_pad, self.x_query, self.x_center, self.x_max, self.is_half = (
57
+ config.x_pad,
58
+ config.x_query,
59
+ config.x_center,
60
+ config.x_max,
61
+ config.is_half,
62
+ )
63
+ self.sr = 16000 # hubert input sampling rate
64
+ self.window = 160 # points per frame
65
+ self.t_pad = self.sr * self.x_pad # Pad time before and after each bar
66
+ self.t_pad_tgt = tgt_sr * self.x_pad
67
+ self.t_pad2 = self.t_pad * 2
68
+ self.t_query = self.sr * self.x_query # Query time before and after the cut point
69
+ self.t_center = self.sr * self.x_center # Query point cut position
70
+ self.t_max = self.sr * self.x_max # Query-free duration threshold
71
+ self.device = config.device
72
+
73
+ def get_f0(
74
+ self,
75
+ input_audio_path,
76
+ x,
77
+ p_len,
78
+ f0_up_key,
79
+ f0_method,
80
+ filter_radius,
81
+ inp_f0=None,
82
+ ):
83
+ global input_audio_path2wav
84
+ time_step = self.window / self.sr * 1000
85
+ f0_min = 50
86
+ f0_max = 1100
87
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
88
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
89
+ if f0_method == "pm":
90
+ f0 = (
91
+ parselmouth.Sound(x, self.sr)
92
+ .to_pitch_ac(
93
+ time_step=time_step / 1000,
94
+ voicing_threshold=0.6,
95
+ pitch_floor=f0_min,
96
+ pitch_ceiling=f0_max,
97
+ )
98
+ .selected_array["frequency"]
99
+ )
100
+ pad_size = (p_len - len(f0) + 1) // 2
101
+ if pad_size > 0 or p_len - len(f0) - pad_size > 0:
102
+ f0 = np.pad(
103
+ f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant"
104
+ )
105
+ elif f0_method == "harvest":
106
+ input_audio_path2wav[input_audio_path] = x.astype(np.double)
107
+ f0 = cache_harvest_f0(input_audio_path, self.sr, f0_max, f0_min, 10)
108
+ if filter_radius > 2:
109
+ f0 = signal.medfilt(f0, 3)
110
+ elif f0_method == "crepe":
111
+ model = "full"
112
+ # Pick a batch size that doesn't cause memory errors on your gpu
113
+ batch_size = 512
114
+ # Compute pitch using first gpu
115
+ audio = torch.tensor(np.copy(x))[None].float()
116
+ f0, pd = torchcrepe.predict(
117
+ audio,
118
+ self.sr,
119
+ self.window,
120
+ f0_min,
121
+ f0_max,
122
+ model,
123
+ batch_size=batch_size,
124
+ device=self.device,
125
+ return_periodicity=True,
126
+ )
127
+ pd = torchcrepe.filter.median(pd, 3)
128
+ f0 = torchcrepe.filter.mean(f0, 3)
129
+ f0[pd < 0.1] = 0
130
+ f0 = f0[0].cpu().numpy()
131
+ elif "rmvpe" in f0_method:
132
+ if hasattr(self, "model_rmvpe") == False:
133
+ from lib.rmvpe import RMVPE
134
+
135
+ logger.info("Loading vocal pitch estimator model")
136
+ self.model_rmvpe = RMVPE(
137
+ "rmvpe.pt", is_half=self.is_half, device=self.device
138
+ )
139
+ thred = 0.03
140
+ if "+" in f0_method:
141
+ f0 = self.model_rmvpe.pitch_based_audio_inference(x, thred, f0_min, f0_max)
142
+ else:
143
+ f0 = self.model_rmvpe.infer_from_audio(x, thred)
144
+
145
+ f0 *= pow(2, f0_up_key / 12)
146
+ # with open("test.txt","w")as f:f.write("\n".join([str(i)for i in f0.tolist()]))
147
+ tf0 = self.sr // self.window # f0 points per second
148
+ if inp_f0 is not None:
149
+ delta_t = np.round(
150
+ (inp_f0[:, 0].max() - inp_f0[:, 0].min()) * tf0 + 1
151
+ ).astype("int16")
152
+ replace_f0 = np.interp(
153
+ list(range(delta_t)), inp_f0[:, 0] * 100, inp_f0[:, 1]
154
+ )
155
+ shape = f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)].shape[0]
156
+ f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)] = replace_f0[
157
+ :shape
158
+ ]
159
+ # with open("test_opt.txt","w")as f:f.write("\n".join([str(i)for i in f0.tolist()]))
160
+ f0bak = f0.copy()
161
+ f0_mel = 1127 * np.log(1 + f0 / 700)
162
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (
163
+ f0_mel_max - f0_mel_min
164
+ ) + 1
165
+ f0_mel[f0_mel <= 1] = 1
166
+ f0_mel[f0_mel > 255] = 255
167
+ try:
168
+ f0_coarse = np.rint(f0_mel).astype(np.int)
169
+ except: # noqa
170
+ f0_coarse = np.rint(f0_mel).astype(int)
171
+ return f0_coarse, f0bak # 1-0
172
+
173
+ def vc(
174
+ self,
175
+ model,
176
+ net_g,
177
+ sid,
178
+ audio0,
179
+ pitch,
180
+ pitchf,
181
+ times,
182
+ index,
183
+ big_npy,
184
+ index_rate,
185
+ version,
186
+ protect,
187
+ ): # ,file_index,file_big_npy
188
+ feats = torch.from_numpy(audio0)
189
+ if self.is_half:
190
+ feats = feats.half()
191
+ else:
192
+ feats = feats.float()
193
+ if feats.dim() == 2: # double channels
194
+ feats = feats.mean(-1)
195
+ assert feats.dim() == 1, feats.dim()
196
+ feats = feats.view(1, -1)
197
+ padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
198
+
199
+ inputs = {
200
+ "source": feats.to(self.device),
201
+ "padding_mask": padding_mask,
202
+ "output_layer": 9 if version == "v1" else 12,
203
+ }
204
+ t0 = ttime()
205
+ with torch.no_grad():
206
+ logits = model.extract_features(**inputs)
207
+ feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
208
+ if protect < 0.5 and pitch != None and pitchf != None:
209
+ feats0 = feats.clone()
210
+ if (
211
+ isinstance(index, type(None)) == False
212
+ and isinstance(big_npy, type(None)) == False
213
+ and index_rate != 0
214
+ ):
215
+ npy = feats[0].cpu().numpy()
216
+ if self.is_half:
217
+ npy = npy.astype("float32")
218
+
219
+ # _, I = index.search(npy, 1)
220
+ # npy = big_npy[I.squeeze()]
221
+
222
+ score, ix = index.search(npy, k=8)
223
+ weight = np.square(1 / score)
224
+ weight /= weight.sum(axis=1, keepdims=True)
225
+ npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
226
+
227
+ if self.is_half:
228
+ npy = npy.astype("float16")
229
+ feats = (
230
+ torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate
231
+ + (1 - index_rate) * feats
232
+ )
233
+
234
+ feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
235
+ if protect < 0.5 and pitch != None and pitchf != None:
236
+ feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
237
+ 0, 2, 1
238
+ )
239
+ t1 = ttime()
240
+ p_len = audio0.shape[0] // self.window
241
+ if feats.shape[1] < p_len:
242
+ p_len = feats.shape[1]
243
+ if pitch != None and pitchf != None:
244
+ pitch = pitch[:, :p_len]
245
+ pitchf = pitchf[:, :p_len]
246
+
247
+ if protect < 0.5 and pitch != None and pitchf != None:
248
+ pitchff = pitchf.clone()
249
+ pitchff[pitchf > 0] = 1
250
+ pitchff[pitchf < 1] = protect
251
+ pitchff = pitchff.unsqueeze(-1)
252
+ feats = feats * pitchff + feats0 * (1 - pitchff)
253
+ feats = feats.to(feats0.dtype)
254
+ p_len = torch.tensor([p_len], device=self.device).long()
255
+ with torch.no_grad():
256
+ if pitch != None and pitchf != None:
257
+ audio1 = (
258
+ (net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0])
259
+ .data.cpu()
260
+ .float()
261
+ .numpy()
262
+ )
263
+ else:
264
+ audio1 = (
265
+ (net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy()
266
+ )
267
+ del feats, p_len, padding_mask
268
+ if torch.cuda.is_available():
269
+ torch.cuda.empty_cache()
270
+ t2 = ttime()
271
+ times[0] += t1 - t0
272
+ times[2] += t2 - t1
273
+ return audio1
274
+
275
+ def pipeline(
276
+ self,
277
+ model,
278
+ net_g,
279
+ sid,
280
+ audio,
281
+ input_audio_path,
282
+ times,
283
+ f0_up_key,
284
+ f0_method,
285
+ file_index,
286
+ # file_big_npy,
287
+ index_rate,
288
+ if_f0,
289
+ filter_radius,
290
+ tgt_sr,
291
+ resample_sr,
292
+ rms_mix_rate,
293
+ version,
294
+ protect,
295
+ f0_file=None,
296
+ ):
297
+ if (
298
+ file_index != ""
299
+ # and file_big_npy != ""
300
+ # and os.path.exists(file_big_npy) == True
301
+ and os.path.exists(file_index) == True
302
+ and index_rate != 0
303
+ ):
304
+ try:
305
+ index = faiss.read_index(file_index)
306
+ # big_npy = np.load(file_big_npy)
307
+ big_npy = index.reconstruct_n(0, index.ntotal)
308
+ except:
309
+ traceback.print_exc()
310
+ index = big_npy = None
311
+ else:
312
+ index = big_npy = None
313
+ logger.warning("File index Not found, set None")
314
+
315
+ audio = signal.filtfilt(bh, ah, audio)
316
+ audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
317
+ opt_ts = []
318
+ if audio_pad.shape[0] > self.t_max:
319
+ audio_sum = np.zeros_like(audio)
320
+ for i in range(self.window):
321
+ audio_sum += audio_pad[i : i - self.window]
322
+ for t in range(self.t_center, audio.shape[0], self.t_center):
323
+ opt_ts.append(
324
+ t
325
+ - self.t_query
326
+ + np.where(
327
+ np.abs(audio_sum[t - self.t_query : t + self.t_query])
328
+ == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min()
329
+ )[0][0]
330
+ )
331
+ s = 0
332
+ audio_opt = []
333
+ t = None
334
+ t1 = ttime()
335
+ audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
336
+ p_len = audio_pad.shape[0] // self.window
337
+ inp_f0 = None
338
+ if hasattr(f0_file, "name") == True:
339
+ try:
340
+ with open(f0_file.name, "r") as f:
341
+ lines = f.read().strip("\n").split("\n")
342
+ inp_f0 = []
343
+ for line in lines:
344
+ inp_f0.append([float(i) for i in line.split(",")])
345
+ inp_f0 = np.array(inp_f0, dtype="float32")
346
+ except:
347
+ traceback.print_exc()
348
+ sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
349
+ pitch, pitchf = None, None
350
+ if if_f0 == 1:
351
+ pitch, pitchf = self.get_f0(
352
+ input_audio_path,
353
+ audio_pad,
354
+ p_len,
355
+ f0_up_key,
356
+ f0_method,
357
+ filter_radius,
358
+ inp_f0,
359
+ )
360
+ pitch = pitch[:p_len]
361
+ pitchf = pitchf[:p_len]
362
+ if self.device == "mps":
363
+ pitchf = pitchf.astype(np.float32)
364
+ pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
365
+ pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
366
+ t2 = ttime()
367
+ times[1] += t2 - t1
368
+ for t in opt_ts:
369
+ t = t // self.window * self.window
370
+ if if_f0 == 1:
371
+ audio_opt.append(
372
+ self.vc(
373
+ model,
374
+ net_g,
375
+ sid,
376
+ audio_pad[s : t + self.t_pad2 + self.window],
377
+ pitch[:, s // self.window : (t + self.t_pad2) // self.window],
378
+ pitchf[:, s // self.window : (t + self.t_pad2) // self.window],
379
+ times,
380
+ index,
381
+ big_npy,
382
+ index_rate,
383
+ version,
384
+ protect,
385
+ )[self.t_pad_tgt : -self.t_pad_tgt]
386
+ )
387
+ else:
388
+ audio_opt.append(
389
+ self.vc(
390
+ model,
391
+ net_g,
392
+ sid,
393
+ audio_pad[s : t + self.t_pad2 + self.window],
394
+ None,
395
+ None,
396
+ times,
397
+ index,
398
+ big_npy,
399
+ index_rate,
400
+ version,
401
+ protect,
402
+ )[self.t_pad_tgt : -self.t_pad_tgt]
403
+ )
404
+ s = t
405
+ if if_f0 == 1:
406
+ audio_opt.append(
407
+ self.vc(
408
+ model,
409
+ net_g,
410
+ sid,
411
+ audio_pad[t:],
412
+ pitch[:, t // self.window :] if t is not None else pitch,
413
+ pitchf[:, t // self.window :] if t is not None else pitchf,
414
+ times,
415
+ index,
416
+ big_npy,
417
+ index_rate,
418
+ version,
419
+ protect,
420
+ )[self.t_pad_tgt : -self.t_pad_tgt]
421
+ )
422
+ else:
423
+ audio_opt.append(
424
+ self.vc(
425
+ model,
426
+ net_g,
427
+ sid,
428
+ audio_pad[t:],
429
+ None,
430
+ None,
431
+ times,
432
+ index,
433
+ big_npy,
434
+ index_rate,
435
+ version,
436
+ protect,
437
+ )[self.t_pad_tgt : -self.t_pad_tgt]
438
+ )
439
+ audio_opt = np.concatenate(audio_opt)
440
+ if rms_mix_rate != 1:
441
+ audio_opt = change_rms(audio, 16000, audio_opt, tgt_sr, rms_mix_rate)
442
+ if resample_sr >= 16000 and tgt_sr != resample_sr:
443
+ audio_opt = librosa.resample(
444
+ audio_opt, orig_sr=tgt_sr, target_sr=resample_sr
445
+ )
446
+ audio_max = np.abs(audio_opt).max() / 0.99
447
+ max_int16 = 32768
448
+ if audio_max > 1:
449
+ max_int16 /= audio_max
450
+ audio_opt = (audio_opt * max_int16).astype(np.int16)
451
+ del pitch, pitchf, sid
452
+ if torch.cuda.is_available():
453
+ torch.cuda.empty_cache()
454
+ return audio_opt
voice_main.py ADDED
@@ -0,0 +1,732 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from soni_translate.logging_setup import logger
2
+ import torch
3
+ import gc
4
+ import numpy as np
5
+ import os
6
+ import shutil
7
+ import warnings
8
+ import threading
9
+ from tqdm import tqdm
10
+ from lib.infer_pack.models import (
11
+ SynthesizerTrnMs256NSFsid,
12
+ SynthesizerTrnMs256NSFsid_nono,
13
+ SynthesizerTrnMs768NSFsid,
14
+ SynthesizerTrnMs768NSFsid_nono,
15
+ )
16
+ from lib.audio import load_audio
17
+ import soundfile as sf
18
+ import edge_tts
19
+ import asyncio
20
+ from soni_translate.utils import remove_directory_contents, create_directories
21
+ from scipy import signal
22
+ from time import time as ttime
23
+ import faiss
24
+ from vci_pipeline import VC, change_rms, bh, ah
25
+ import librosa
26
+
27
+ warnings.filterwarnings("ignore")
28
+
29
+
30
+ class Config:
31
+ def __init__(self, only_cpu=False):
32
+ self.device = "cuda:0"
33
+ self.is_half = True
34
+ self.n_cpu = 0
35
+ self.gpu_name = None
36
+ self.gpu_mem = None
37
+ (
38
+ self.x_pad,
39
+ self.x_query,
40
+ self.x_center,
41
+ self.x_max
42
+ ) = self.device_config(only_cpu)
43
+
44
+ def device_config(self, only_cpu) -> tuple:
45
+ if torch.cuda.is_available() and not only_cpu:
46
+ i_device = int(self.device.split(":")[-1])
47
+ self.gpu_name = torch.cuda.get_device_name(i_device)
48
+ if (
49
+ ("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
50
+ or "P40" in self.gpu_name.upper()
51
+ or "1060" in self.gpu_name
52
+ or "1070" in self.gpu_name
53
+ or "1080" in self.gpu_name
54
+ ):
55
+ logger.info(
56
+ "16/10 Series GPUs and P40 excel "
57
+ "in single-precision tasks."
58
+ )
59
+ self.is_half = False
60
+ else:
61
+ self.gpu_name = None
62
+ self.gpu_mem = int(
63
+ torch.cuda.get_device_properties(i_device).total_memory
64
+ / 1024
65
+ / 1024
66
+ / 1024
67
+ + 0.4
68
+ )
69
+ elif torch.backends.mps.is_available() and not only_cpu:
70
+ logger.info("Supported N-card not found, using MPS for inference")
71
+ self.device = "mps"
72
+ else:
73
+ logger.info("No supported N-card found, using CPU for inference")
74
+ self.device = "cpu"
75
+ self.is_half = False
76
+
77
+ if self.n_cpu == 0:
78
+ self.n_cpu = os.cpu_count()
79
+
80
+ if self.is_half:
81
+ # 6GB VRAM configuration
82
+ x_pad = 3
83
+ x_query = 10
84
+ x_center = 60
85
+ x_max = 65
86
+ else:
87
+ # 5GB VRAM configuration
88
+ x_pad = 1
89
+ x_query = 6
90
+ x_center = 38
91
+ x_max = 41
92
+
93
+ if self.gpu_mem is not None and self.gpu_mem <= 4:
94
+ x_pad = 1
95
+ x_query = 5
96
+ x_center = 30
97
+ x_max = 32
98
+
99
+ logger.info(
100
+ f"Config: Device is {self.device}, "
101
+ f"half precision is {self.is_half}"
102
+ )
103
+
104
+ return x_pad, x_query, x_center, x_max
105
+
106
+
107
+ BASE_DOWNLOAD_LINK = "https://huggingface.co/r3gm/sonitranslate_voice_models/resolve/main/"
108
+ BASE_MODELS = [
109
+ "hubert_base.pt",
110
+ "rmvpe.pt"
111
+ ]
112
+ BASE_DIR = "."
113
+
114
+
115
+ def load_hu_bert(config):
116
+ from fairseq import checkpoint_utils
117
+ from soni_translate.utils import download_manager
118
+
119
+ for id_model in BASE_MODELS:
120
+ download_manager(
121
+ os.path.join(BASE_DOWNLOAD_LINK, id_model), BASE_DIR
122
+ )
123
+
124
+ models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
125
+ ["hubert_base.pt"],
126
+ suffix="",
127
+ )
128
+ hubert_model = models[0]
129
+ hubert_model = hubert_model.to(config.device)
130
+ if config.is_half:
131
+ hubert_model = hubert_model.half()
132
+ else:
133
+ hubert_model = hubert_model.float()
134
+ hubert_model.eval()
135
+
136
+ return hubert_model
137
+
138
+
139
+ def load_trained_model(model_path, config):
140
+
141
+ if not model_path:
142
+ raise ValueError("No model found")
143
+
144
+ logger.info("Loading %s" % model_path)
145
+ cpt = torch.load(model_path, map_location="cpu")
146
+ tgt_sr = cpt["config"][-1]
147
+ cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
148
+ if_f0 = cpt.get("f0", 1)
149
+ if if_f0 == 0:
150
+ # protect to 0.5 need?
151
+ pass
152
+
153
+ version = cpt.get("version", "v1")
154
+ if version == "v1":
155
+ if if_f0 == 1:
156
+ net_g = SynthesizerTrnMs256NSFsid(
157
+ *cpt["config"], is_half=config.is_half
158
+ )
159
+ else:
160
+ net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
161
+ elif version == "v2":
162
+ if if_f0 == 1:
163
+ net_g = SynthesizerTrnMs768NSFsid(
164
+ *cpt["config"], is_half=config.is_half
165
+ )
166
+ else:
167
+ net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
168
+ del net_g.enc_q
169
+
170
+ net_g.load_state_dict(cpt["weight"], strict=False)
171
+ net_g.eval().to(config.device)
172
+
173
+ if config.is_half:
174
+ net_g = net_g.half()
175
+ else:
176
+ net_g = net_g.float()
177
+
178
+ vc = VC(tgt_sr, config)
179
+ n_spk = cpt["config"][-3]
180
+
181
+ return n_spk, tgt_sr, net_g, vc, cpt, version
182
+
183
+
184
+ class ClassVoices:
185
+ def __init__(self, only_cpu=False):
186
+ self.model_config = {}
187
+ self.config = None
188
+ self.only_cpu = only_cpu
189
+
190
+ def apply_conf(
191
+ self,
192
+ tag="base_model",
193
+ file_model="",
194
+ pitch_algo="pm",
195
+ pitch_lvl=0,
196
+ file_index="",
197
+ index_influence=0.66,
198
+ respiration_median_filtering=3,
199
+ envelope_ratio=0.25,
200
+ consonant_breath_protection=0.33,
201
+ resample_sr=0,
202
+ file_pitch_algo="",
203
+ ):
204
+
205
+ if not file_model:
206
+ raise ValueError("Model not found")
207
+
208
+ if file_index is None:
209
+ file_index = ""
210
+
211
+ if file_pitch_algo is None:
212
+ file_pitch_algo = ""
213
+
214
+ if not self.config:
215
+ self.config = Config(self.only_cpu)
216
+ self.hu_bert_model = None
217
+ self.model_pitch_estimator = None
218
+
219
+ self.model_config[tag] = {
220
+ "file_model": file_model,
221
+ "pitch_algo": pitch_algo,
222
+ "pitch_lvl": pitch_lvl, # no decimal
223
+ "file_index": file_index,
224
+ "index_influence": index_influence,
225
+ "respiration_median_filtering": respiration_median_filtering,
226
+ "envelope_ratio": envelope_ratio,
227
+ "consonant_breath_protection": consonant_breath_protection,
228
+ "resample_sr": resample_sr,
229
+ "file_pitch_algo": file_pitch_algo,
230
+ }
231
+ return f"CONFIGURATION APPLIED FOR {tag}: {file_model}"
232
+
233
+ def infer(
234
+ self,
235
+ task_id,
236
+ params,
237
+ # load model
238
+ n_spk,
239
+ tgt_sr,
240
+ net_g,
241
+ pipe,
242
+ cpt,
243
+ version,
244
+ if_f0,
245
+ # load index
246
+ index_rate,
247
+ index,
248
+ big_npy,
249
+ # load f0 file
250
+ inp_f0,
251
+ # audio file
252
+ input_audio_path,
253
+ overwrite,
254
+ ):
255
+
256
+ f0_method = params["pitch_algo"]
257
+ f0_up_key = params["pitch_lvl"]
258
+ filter_radius = params["respiration_median_filtering"]
259
+ resample_sr = params["resample_sr"]
260
+ rms_mix_rate = params["envelope_ratio"]
261
+ protect = params["consonant_breath_protection"]
262
+
263
+ if not os.path.exists(input_audio_path):
264
+ raise ValueError(
265
+ "The audio file was not found or is not "
266
+ f"a valid file: {input_audio_path}"
267
+ )
268
+
269
+ f0_up_key = int(f0_up_key)
270
+
271
+ audio = load_audio(input_audio_path, 16000)
272
+
273
+ # Normalize audio
274
+ audio_max = np.abs(audio).max() / 0.95
275
+ if audio_max > 1:
276
+ audio /= audio_max
277
+
278
+ times = [0, 0, 0]
279
+
280
+ # filters audio signal, pads it, computes sliding window sums,
281
+ # and extracts optimized time indices
282
+ audio = signal.filtfilt(bh, ah, audio)
283
+ audio_pad = np.pad(
284
+ audio, (pipe.window // 2, pipe.window // 2), mode="reflect"
285
+ )
286
+ opt_ts = []
287
+ if audio_pad.shape[0] > pipe.t_max:
288
+ audio_sum = np.zeros_like(audio)
289
+ for i in range(pipe.window):
290
+ audio_sum += audio_pad[i:i - pipe.window]
291
+ for t in range(pipe.t_center, audio.shape[0], pipe.t_center):
292
+ opt_ts.append(
293
+ t
294
+ - pipe.t_query
295
+ + np.where(
296
+ np.abs(audio_sum[t - pipe.t_query: t + pipe.t_query])
297
+ == np.abs(audio_sum[t - pipe.t_query: t + pipe.t_query]).min()
298
+ )[0][0]
299
+ )
300
+
301
+ s = 0
302
+ audio_opt = []
303
+ t = None
304
+ t1 = ttime()
305
+
306
+ sid_value = 0
307
+ sid = torch.tensor(sid_value, device=pipe.device).unsqueeze(0).long()
308
+
309
+ # Pads audio symmetrically, calculates length divided by window size.
310
+ audio_pad = np.pad(audio, (pipe.t_pad, pipe.t_pad), mode="reflect")
311
+ p_len = audio_pad.shape[0] // pipe.window
312
+
313
+ # Estimates pitch from audio signal
314
+ pitch, pitchf = None, None
315
+ if if_f0 == 1:
316
+ pitch, pitchf = pipe.get_f0(
317
+ input_audio_path,
318
+ audio_pad,
319
+ p_len,
320
+ f0_up_key,
321
+ f0_method,
322
+ filter_radius,
323
+ inp_f0,
324
+ )
325
+ pitch = pitch[:p_len]
326
+ pitchf = pitchf[:p_len]
327
+ if pipe.device == "mps":
328
+ pitchf = pitchf.astype(np.float32)
329
+ pitch = torch.tensor(
330
+ pitch, device=pipe.device
331
+ ).unsqueeze(0).long()
332
+ pitchf = torch.tensor(
333
+ pitchf, device=pipe.device
334
+ ).unsqueeze(0).float()
335
+
336
+ t2 = ttime()
337
+ times[1] += t2 - t1
338
+ for t in opt_ts:
339
+ t = t // pipe.window * pipe.window
340
+ if if_f0 == 1:
341
+ pitch_slice = pitch[
342
+ :, s // pipe.window: (t + pipe.t_pad2) // pipe.window
343
+ ]
344
+ pitchf_slice = pitchf[
345
+ :, s // pipe.window: (t + pipe.t_pad2) // pipe.window
346
+ ]
347
+ else:
348
+ pitch_slice = None
349
+ pitchf_slice = None
350
+
351
+ audio_slice = audio_pad[s:t + pipe.t_pad2 + pipe.window]
352
+ audio_opt.append(
353
+ pipe.vc(
354
+ self.hu_bert_model,
355
+ net_g,
356
+ sid,
357
+ audio_slice,
358
+ pitch_slice,
359
+ pitchf_slice,
360
+ times,
361
+ index,
362
+ big_npy,
363
+ index_rate,
364
+ version,
365
+ protect,
366
+ )[pipe.t_pad_tgt:-pipe.t_pad_tgt]
367
+ )
368
+ s = t
369
+
370
+ pitch_end_slice = pitch[
371
+ :, t // pipe.window:
372
+ ] if t is not None else pitch
373
+ pitchf_end_slice = pitchf[
374
+ :, t // pipe.window:
375
+ ] if t is not None else pitchf
376
+
377
+ audio_opt.append(
378
+ pipe.vc(
379
+ self.hu_bert_model,
380
+ net_g,
381
+ sid,
382
+ audio_pad[t:],
383
+ pitch_end_slice,
384
+ pitchf_end_slice,
385
+ times,
386
+ index,
387
+ big_npy,
388
+ index_rate,
389
+ version,
390
+ protect,
391
+ )[pipe.t_pad_tgt:-pipe.t_pad_tgt]
392
+ )
393
+
394
+ audio_opt = np.concatenate(audio_opt)
395
+ if rms_mix_rate != 1:
396
+ audio_opt = change_rms(
397
+ audio, 16000, audio_opt, tgt_sr, rms_mix_rate
398
+ )
399
+ if resample_sr >= 16000 and tgt_sr != resample_sr:
400
+ audio_opt = librosa.resample(
401
+ audio_opt, orig_sr=tgt_sr, target_sr=resample_sr
402
+ )
403
+ audio_max = np.abs(audio_opt).max() / 0.99
404
+ max_int16 = 32768
405
+ if audio_max > 1:
406
+ max_int16 /= audio_max
407
+ audio_opt = (audio_opt * max_int16).astype(np.int16)
408
+ del pitch, pitchf, sid
409
+ if torch.cuda.is_available():
410
+ torch.cuda.empty_cache()
411
+
412
+ if tgt_sr != resample_sr >= 16000:
413
+ final_sr = resample_sr
414
+ else:
415
+ final_sr = tgt_sr
416
+
417
+ """
418
+ "Success.\n %s\nTime:\n npy:%ss, f0:%ss, infer:%ss" % (
419
+ times[0],
420
+ times[1],
421
+ times[2],
422
+ ), (final_sr, audio_opt)
423
+
424
+ """
425
+
426
+ if overwrite:
427
+ output_audio_path = input_audio_path # Overwrite
428
+ else:
429
+ basename = os.path.basename(input_audio_path)
430
+ dirname = os.path.dirname(input_audio_path)
431
+
432
+ new_basename = basename.split(
433
+ '.')[0] + "_edited." + basename.split('.')[-1]
434
+ new_path = os.path.join(dirname, new_basename)
435
+ logger.info(str(new_path))
436
+
437
+ output_audio_path = new_path
438
+
439
+ # Save file
440
+ sf.write(
441
+ file=output_audio_path,
442
+ samplerate=final_sr,
443
+ data=audio_opt
444
+ )
445
+
446
+ self.model_config[task_id]["result"].append(output_audio_path)
447
+ self.output_list.append(output_audio_path)
448
+
449
+ def make_test(
450
+ self,
451
+ tts_text,
452
+ tts_voice,
453
+ model_path,
454
+ index_path,
455
+ transpose,
456
+ f0_method,
457
+ ):
458
+
459
+ folder_test = "test"
460
+ tag = "test_edge"
461
+ tts_file = "test/test.wav"
462
+ tts_edited = "test/test_edited.wav"
463
+
464
+ create_directories(folder_test)
465
+ remove_directory_contents(folder_test)
466
+
467
+ if "SET_LIMIT" == os.getenv("DEMO"):
468
+ if len(tts_text) > 60:
469
+ tts_text = tts_text[:60]
470
+ logger.warning("DEMO; limit to 60 characters")
471
+
472
+ try:
473
+ asyncio.run(edge_tts.Communicate(
474
+ tts_text, "-".join(tts_voice.split('-')[:-1])
475
+ ).save(tts_file))
476
+ except Exception as e:
477
+ raise ValueError(
478
+ "No audio was received. Please change the "
479
+ f"tts voice for {tts_voice}. Error: {str(e)}"
480
+ )
481
+
482
+ shutil.copy(tts_file, tts_edited)
483
+
484
+ self.apply_conf(
485
+ tag=tag,
486
+ file_model=model_path,
487
+ pitch_algo=f0_method,
488
+ pitch_lvl=transpose,
489
+ file_index=index_path,
490
+ index_influence=0.66,
491
+ respiration_median_filtering=3,
492
+ envelope_ratio=0.25,
493
+ consonant_breath_protection=0.33,
494
+ )
495
+
496
+ self(
497
+ audio_files=tts_edited,
498
+ tag_list=tag,
499
+ overwrite=True
500
+ )
501
+
502
+ return tts_edited, tts_file
503
+
504
+ def run_threads(self, threads):
505
+ # Start threads
506
+ for thread in threads:
507
+ thread.start()
508
+
509
+ # Wait for all threads to finish
510
+ for thread in threads:
511
+ thread.join()
512
+
513
+ gc.collect()
514
+ torch.cuda.empty_cache()
515
+
516
+ def unload_models(self):
517
+ self.hu_bert_model = None
518
+ self.model_pitch_estimator = None
519
+ gc.collect()
520
+ torch.cuda.empty_cache()
521
+
522
+ def __call__(
523
+ self,
524
+ audio_files=[],
525
+ tag_list=[],
526
+ overwrite=False,
527
+ parallel_workers=1,
528
+ ):
529
+ logger.info(f"Parallel workers: {str(parallel_workers)}")
530
+
531
+ self.output_list = []
532
+
533
+ if not self.model_config:
534
+ raise ValueError("No model has been configured for inference")
535
+
536
+ if isinstance(audio_files, str):
537
+ audio_files = [audio_files]
538
+ if isinstance(tag_list, str):
539
+ tag_list = [tag_list]
540
+
541
+ if not audio_files:
542
+ raise ValueError("No audio found to convert")
543
+ if not tag_list:
544
+ tag_list = [list(self.model_config.keys())[-1]] * len(audio_files)
545
+
546
+ if len(audio_files) > len(tag_list):
547
+ logger.info("Extend tag list to match audio files")
548
+ extend_number = len(audio_files) - len(tag_list)
549
+ tag_list.extend([tag_list[0]] * extend_number)
550
+
551
+ if len(audio_files) < len(tag_list):
552
+ logger.info("Cut list tags")
553
+ tag_list = tag_list[:len(audio_files)]
554
+
555
+ tag_file_pairs = list(zip(tag_list, audio_files))
556
+ sorted_tag_file = sorted(tag_file_pairs, key=lambda x: x[0])
557
+
558
+ # Base params
559
+ if not self.hu_bert_model:
560
+ self.hu_bert_model = load_hu_bert(self.config)
561
+
562
+ cache_params = None
563
+ threads = []
564
+ progress_bar = tqdm(total=len(tag_list), desc="Progress")
565
+ for i, (id_tag, input_audio_path) in enumerate(sorted_tag_file):
566
+
567
+ if id_tag not in self.model_config.keys():
568
+ logger.info(
569
+ f"No configured model for {id_tag} with {input_audio_path}"
570
+ )
571
+ continue
572
+
573
+ if (
574
+ len(threads) >= parallel_workers
575
+ or cache_params != id_tag
576
+ and cache_params is not None
577
+ ):
578
+
579
+ self.run_threads(threads)
580
+ progress_bar.update(len(threads))
581
+
582
+ threads = []
583
+
584
+ if cache_params != id_tag:
585
+
586
+ self.model_config[id_tag]["result"] = []
587
+
588
+ # Unload previous
589
+ (
590
+ n_spk,
591
+ tgt_sr,
592
+ net_g,
593
+ pipe,
594
+ cpt,
595
+ version,
596
+ if_f0,
597
+ index_rate,
598
+ index,
599
+ big_npy,
600
+ inp_f0,
601
+ ) = [None] * 11
602
+ gc.collect()
603
+ torch.cuda.empty_cache()
604
+
605
+ # Model params
606
+ params = self.model_config[id_tag]
607
+
608
+ model_path = params["file_model"]
609
+ f0_method = params["pitch_algo"]
610
+ file_index = params["file_index"]
611
+ index_rate = params["index_influence"]
612
+ f0_file = params["file_pitch_algo"]
613
+
614
+ # Load model
615
+ (
616
+ n_spk,
617
+ tgt_sr,
618
+ net_g,
619
+ pipe,
620
+ cpt,
621
+ version
622
+ ) = load_trained_model(model_path, self.config)
623
+ if_f0 = cpt.get("f0", 1) # pitch data
624
+
625
+ # Load index
626
+ if os.path.exists(file_index) and index_rate != 0:
627
+ try:
628
+ index = faiss.read_index(file_index)
629
+ big_npy = index.reconstruct_n(0, index.ntotal)
630
+ except Exception as error:
631
+ logger.error(f"Index: {str(error)}")
632
+ index_rate = 0
633
+ index = big_npy = None
634
+ else:
635
+ logger.warning("File index not found")
636
+ index_rate = 0
637
+ index = big_npy = None
638
+
639
+ # Load f0 file
640
+ inp_f0 = None
641
+ if os.path.exists(f0_file):
642
+ try:
643
+ with open(f0_file, "r") as f:
644
+ lines = f.read().strip("\n").split("\n")
645
+ inp_f0 = []
646
+ for line in lines:
647
+ inp_f0.append([float(i) for i in line.split(",")])
648
+ inp_f0 = np.array(inp_f0, dtype="float32")
649
+ except Exception as error:
650
+ logger.error(f"f0 file: {str(error)}")
651
+
652
+ if "rmvpe" in f0_method:
653
+ if not self.model_pitch_estimator:
654
+ from lib.rmvpe import RMVPE
655
+
656
+ logger.info("Loading vocal pitch estimator model")
657
+ self.model_pitch_estimator = RMVPE(
658
+ "rmvpe.pt",
659
+ is_half=self.config.is_half,
660
+ device=self.config.device
661
+ )
662
+
663
+ pipe.model_rmvpe = self.model_pitch_estimator
664
+
665
+ cache_params = id_tag
666
+
667
+ # self.infer(
668
+ # id_tag,
669
+ # params,
670
+ # # load model
671
+ # n_spk,
672
+ # tgt_sr,
673
+ # net_g,
674
+ # pipe,
675
+ # cpt,
676
+ # version,
677
+ # if_f0,
678
+ # # load index
679
+ # index_rate,
680
+ # index,
681
+ # big_npy,
682
+ # # load f0 file
683
+ # inp_f0,
684
+ # # output file
685
+ # input_audio_path,
686
+ # overwrite,
687
+ # )
688
+
689
+ thread = threading.Thread(
690
+ target=self.infer,
691
+ args=(
692
+ id_tag,
693
+ params,
694
+ # loaded model
695
+ n_spk,
696
+ tgt_sr,
697
+ net_g,
698
+ pipe,
699
+ cpt,
700
+ version,
701
+ if_f0,
702
+ # loaded index
703
+ index_rate,
704
+ index,
705
+ big_npy,
706
+ # loaded f0 file
707
+ inp_f0,
708
+ # audio file
709
+ input_audio_path,
710
+ overwrite,
711
+ )
712
+ )
713
+
714
+ threads.append(thread)
715
+
716
+ # Run last
717
+ if threads:
718
+ self.run_threads(threads)
719
+
720
+ progress_bar.update(len(threads))
721
+ progress_bar.close()
722
+
723
+ final_result = []
724
+ valid_tags = set(tag_list)
725
+ for tag in valid_tags:
726
+ if (
727
+ tag in self.model_config.keys()
728
+ and "result" in self.model_config[tag].keys()
729
+ ):
730
+ final_result.extend(self.model_config[tag]["result"])
731
+
732
+ return final_result